Skip to content

Commit

Permalink
Mirt redux (#72)
Browse files Browse the repository at this point in the history
* Add postposterior covariance matrix based mirt rules

* Qualify usage of even_grid in dt test

* Formatting of ability_estimator

* Add todo note to comparison.jl

* Refactor dispersion around ScalarizedStateCriteron

* Apply formatting
  • Loading branch information
frankier authored Oct 30, 2024
1 parent ec41675 commit 385052f
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 108 deletions.
1 change: 1 addition & 0 deletions src/Comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecution
items_answered = items_answered
)
if :after_item_criteria in comparison.phases
# TOOD: Combine with next_item if possible and requested?
timed_item_criteria = @timed Stateful.item_criteria(cat)
measure_all(
comparison,
Expand Down
2 changes: 1 addition & 1 deletion src/aggregators/Aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using PsychometricsBazaarBase.ConfigTools
using PsychometricsBazaarBase.Integrators
using PsychometricsBazaarBase: Integrators
using PsychometricsBazaarBase.Optimizers
using PsychometricsBazaarBase.ConstDistributions: std_normal
using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal

import FittedItemBanks
import PsychometricsBazaarBase.IntegralCoeffs
Expand Down
53 changes: 52 additions & 1 deletion src/aggregators/ability_estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstim
prior::PriorT
end

PriorAbilityEstimator() = PriorAbilityEstimator(std_normal)
function PriorAbilityEstimator(; ncomp = 0)
if ncomp == 0
return PriorAbilityEstimator(std_normal)
else
return PriorAbilityEstimator(std_mv_normal(ncomp))
end
end

function pdf(est::PriorAbilityEstimator,
tracked_responses::TrackedResponses)
Expand Down Expand Up @@ -73,6 +79,21 @@ function mean_1d(integrator::AbilityIntegrator,
denom)
end

function mean(
integrator::AbilityIntegrator,
est::DistributionAbilityEstimator,
tracked_responses::TrackedResponses,
denom = normdenom(integrator, est, tracked_responses)
)
n = domdims(tracked_responses.item_bank)
expectation(IntegralCoeffs.id,
n,
integrator,
est,
tracked_responses,
denom)
end

function variance_given_mean(integrator::AbilityIntegrator,
est::DistributionAbilityEstimator,
tracked_responses::TrackedResponses,
Expand All @@ -97,6 +118,36 @@ function variance(integrator::AbilityIntegrator,
denom)
end

function covariance_matrix_given_mean(
integrator::AbilityIntegrator,
est::DistributionAbilityEstimator,
tracked_responses::TrackedResponses,
mean,
denom = normdenom(integrator, est, tracked_responses)
)
n = domdims(tracked_responses.item_bank)
expectation(IntegralCoeffs.OuterProdDev(mean),
n,
integrator,
est,
tracked_responses,
denom)
end

function covariance_matrix(
integrator::AbilityIntegrator,
est::DistributionAbilityEstimator,
tracked_responses::TrackedResponses,
denom = normdenom(integrator, est, tracked_responses))
covariance_matrix_given_mean(
integrator,
est,
tracked_responses,
mean(integrator, est, tracked_responses, denom),
denom
)
end

struct ModeAbilityEstimator{
DistEst <: DistributionAbilityEstimator,
OptimizerT <: AbilityOptimizer
Expand Down
9 changes: 8 additions & 1 deletion src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,25 @@ import PsychometricsBazaarBase.IntegralCoeffs
using FittedItemBanks
using FittedItemBanks: item_params
using ..Aggregators
using ..Aggregators: covariance_matrix

using Distributions, Base.Threads, Base.Order, StaticArrays
using ConstructionBase: constructorof
import ForwardDiff

export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread
export NextItemRule, ItemStrategyNextItemRule
export UrryItemCriterion, InformationItemCriterion, DRuleItemCriterion, TRuleItemCriterion
export UrryItemCriterion, InformationItemCriterion
export RandomNextItemRule
export ExhaustiveSearch1Ply
export catr_next_item_aliases
export preallocate
export compute_criteria
export PointResponseExpectation, DistributionResponseExpectation
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron

"""
$(TYPEDEF)
Expand Down Expand Up @@ -68,6 +73,7 @@ end

include("./random.jl")
include("./information.jl")
include("./information_special.jl")
include("./objective_function.jl")
include("./expectation.jl")

Expand Down Expand Up @@ -197,6 +203,7 @@ function compute_criteria(
compute_criteria(rule.criterion, responses, items)
end

include("./mirt.jl")
include("./aliases.jl")
include("./preallocate.jl")

Expand Down
8 changes: 7 additions & 1 deletion src/next_item_rules/aliases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ const mirtcat_next_item_aliases = Dict(
# 'MEPV' for minimum expected posterior variance
"MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion(
ability_estimator,
AbilityVarianceStateCriterion(bits...)))
AbilityVarianceStateCriterion(bits...))),
"Drule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
InformationMatrixCriteria(ability_estimator),
DeterminantScalarizer())),
"Trule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
InformationMatrixCriteria(ability_estimator),
TraceScalarizer()))
)

# 'MLWI' for maximum likelihood weighted information
Expand Down
37 changes: 22 additions & 15 deletions src/next_item_rules/information.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ using FittedItemBanks: CdfMirtItemBank,
using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size
using StatsFuns: logaddexp

function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ)
= norm_abil(ir, θ)
return SVector(
logccdf(ir.item_bank.distribution, nθ),
logcdf(ir.item_bank.distribution, nθ)
)
end

function log_resp(ir::ItemResponse{<:TransferItemBank}, resp, θ)
logcdf(ir.item_bank.distribution, norm_abil(ir, θ))
end

function log_resp_vec(ir::ItemResponse{<:CdfMirtItemBank}, θ)
= norm_abil(ir, θ)
SVector(logccdf(ir.item_bank.distribution, nθ),
Expand Down Expand Up @@ -52,26 +64,21 @@ function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ)
log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ))
end

# How does this compare with expected_item_information. Speeds/accuracies?
# TODO: Which response models is this valid for?
# TODO: Citation/source for this equation
# TODO: Do it in log space?
function item_information(ir::ItemResponse, θ)
# irθ_prime = ForwardDiff.derivative(ir, θ)
irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ)
irθ = resp(ir, θ)
if irθ_prime == 0.0
return 0.0
else
return (irθ_prime * irθ_prime) / (irθ * (1 - irθ))
end
end

function vector_hessian(f, x, n)
out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x)
return reshape(out, n, n, n)
end

function double_derivative(f, x)
ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x)
end

function expected_item_information(ir::ItemResponse, θ::Float64)
exp_resp = resp_vec(ir, θ)
= double_derivative((θ -> log_resp_vec(ir, θ)), θ)
-sum(exp_resp .* d²)
end

# TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion
# TODO: This is not implementing DRule but postposterior DRule
function expected_item_information(ir::ItemResponse, θ::Vector{Float64})
Expand Down
65 changes: 65 additions & 0 deletions src/next_item_rules/information_special.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#=
This file contains some specialised ways to calculate information.
For some models analytical solutions are possible for information.
Most are simple applications of the chain rule
However, I haven't taken a systematic approach yet yet.
So these are just from equations in the literature.
There aren't really any type guards on these so its up to the caller to make sure they are using the right ones.
=#

function alt_expected_1d_item_information(ir::ItemResponse, θ)
"""
This is a special case of the expected_item_information function for
* 1-dimensional ability
* Dichotomous items
* It should be valid for at least up to the 3PL model, probably others too
TODO: citation
"""
# irθ_prime = ForwardDiff.derivative(ir, θ)
irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ)
irθ = resp(ir, θ)
if irθ_prime == 0.0
return 0.0
else
return (irθ_prime * irθ_prime) / (irθ * (1 - irθ))
end
end

function alt_expected_mirt_item_information(ir::ItemResponse, θ)
"""
This is a special case of the expected_item_information function for
* Multidimensional
* Dichotomous items
* It should be valid for at least up to the 3PL model, probably others too
TODO: citation
"""
irθ_prime = ForwardDiff.gradient(x -> resp(ir, x), θ)
= resp(ir, θ)
= 1 -
(irθ_prime * irθ_prime') / (pθ * qθ)
end

function alt_expected_mirt_3pl_item_information(ir::ItemResponse, θ)
"""
This is a special case of the expected_item_information function for
* Multidimensional
* Dichotomous items
* 3PL model only
Mulder J, van der Linden WJ.
Multidimensional Adaptive Testing with Optimal Design Criteria for Item Selection.
Psychometrika. 2009 Jun;74(2):273-296. doi: 10.1007/s11336-008-9097-5.
Equation 4
"""
# XXX: Should avoid using item_params
params = item_params(ir.item_bank.discriminations, ir.index)
= resp(ir, θ)
= 1 -
a = params.discrimination
c = params.guess
common_factor = (qθ * (pθ - c)^2) / (pθ * (1 - c)^2)
common_factor * (a * a')
end
128 changes: 128 additions & 0 deletions src/next_item_rules/mirt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
abstract type MatrixScalarizer end

struct DeterminantScalarizer <: MatrixScalarizer end
(::DeterminantScalarizer)(mat) = det(mat)

struct TraceScalarizer <: MatrixScalarizer end
(::TraceScalarizer)(mat) = tr(mat)

abstract type StateCriteria end
abstract type ItemCriteria end

struct AbilityCovarianceStateCriteria{
DistEstT <: DistributionAbilityEstimator,
IntegratorT <: AbilityIntegrator
} <: StateCriteria
dist_est::DistEstT
integrator::IntegratorT
skip_zero::Bool
end

function AbilityCovarianceStateCriteria(bits...)
skip_zero = false
@requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...)
return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero)
end

# XXX: Should be at type level
should_minimize(::AbilityCovarianceStateCriteria) = true

function (criteria::AbilityCovarianceStateCriteria)(
tracked_responses::TrackedResponses,
denom = normdenom(criteria.integrator,
criteria.dist_est,
tracked_responses)
)
if denom == 0.0 && criteria.skip_zero
return Inf
end
covariance_matrix(
criteria.integrator,
criteria.dist_est,
tracked_responses,
denom
)
end

struct ScalarizedStateCriteron{
StateCriteriaT <: StateCriteria,
MatrixScalarizerT <: MatrixScalarizer
} <: StateCriterion
criteria::StateCriteriaT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedStateCriteron)(tracked_responses)
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
ability_estimator::AbilityEstimatorT
expected_item_information::F
end

function InformationMatrixCriteria(ability_estimator)
InformationMatrixCriteria(ability_estimator, expected_item_information)
end

function init_thread(item_criterion::InformationMatrixCriteria,
responses::TrackedResponses)
# TODO: No need to do this one per thread. It just need to be done once per
# θ update.
# TODO: Update this to use track!(...) mechanism
ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator)
responses_information(responses.item_bank, responses.responses, ability)
end

function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
tracked_responses::TrackedResponses,
item_idx)
# TODO: Add in information from the prior
ability = maybe_tracked_ability_estimate(
tracked_responses, item_criterion.ability_estimator)
return acc_info .+
item_criterion.expected_item_information(
ItemResponse(tracked_responses.item_bank, item_idx), ability)
end

should_minimize(::InformationMatrixCriteria) = false

struct ScalarizedItemCriteron{
ItemCriteriaT <: ItemCriteria,
MatrixScalarizerT <: MatrixScalarizer
} <: ItemCriterion
criteria::ItemCriteriaT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
res = ssc.criteria(
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end

struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end
Loading

0 comments on commit 385052f

Please sign in to comment.