-
Notifications
You must be signed in to change notification settings - Fork 0
/
2_08_os.jl
104 lines (94 loc) · 2.28 KB
/
2_08_os.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
using DeepPumas, Pumas, CSV, DataFrames, StableRNGs, Random, CairoMakie, PumasPlots, DataFramesMeta
# ---------------------------------------------------------------------------
# Data: read the CSV into a DataFrame, then build a Pumas population whose only
# observation is `Death`, with WT, AGE, SEX, ECOG and ALBB attached as
# covariates. `event_data = false` presumably tells `read_pumas` that the table
# carries no dosing-event records to parse.
# ---------------------------------------------------------------------------
filepath = "data/tgi_os_data.csv"
df = DataFrame(CSV.File(filepath))
os_pop = read_pumas(
    df;
    covariates = [:WT, :AGE, :SEX, :ECOG, :ALBB],
    observations = [:Death],
    event_data = false,
)

# Train/validation split: roughly the first 75% of subjects (in file order, no
# shuffling, so the split is deterministic) are used for training and the rest
# are held out for validation.
os_cutoff = round(Int, 0.75 * length(os_pop))
os_tpop, os_vpop = os_pop[1:os_cutoff], os_pop[(os_cutoff + 1):end]
# os_model1: covariate-free overall-survival (time-to-event) model.
# The hazard depends on time only:
#     λ(t) = NN_λ(t / 2000) + base_λ
# The cumulative hazard Λ is integrated as an ODE state, and `Death` is
# modelled with `TimeToEvent(λ, Λ)`.
os_model1 = @model begin
@param begin
# Neural network for the time-varying part of the hazard. The arguments read
# as: 1 input, two hidden layers of width 15, and a width-1 output layer with
# `softplus` activation, so the network output is non-negative.
# The L2 penalty (weight 0.01) is switched on for the input-layer weights,
# the output-layer weights and the biases.
NN_λ ∈ MLPDomain(
1, 15, 15, (1, softplus);
reg = L2(0.01; input = true, output = true, bias = true),
)
# Constant hazard floor. Its lower bound of 1e-8 keeps the total hazard
# strictly positive even when the network output is near zero.
base_λ ∈ RealDomain(; lower=1e-8, init=1e-3)
end
@pre begin
# `first ∘ NN_λ` unwraps the network's one-element output into a scalar.
λf = first ∘ NN_λ
_base_λ = base_λ
end
@dynamics begin
# `it` grows at unit rate. Assuming the default initial value of 0, it equals
# elapsed time and serves as the network's time input.
it' = 1.0
# Cumulative hazard: dΛ/dt = λ(t).
# NOTE(review): dividing time by 2000 looks like a time-scale normalisation
# (roughly the follow-up horizon?); confirm against the data's time units.
Λ' = λf(it / 2000) + _base_λ
end
@derived begin
# `:=` marks these as intermediate quantities rather than model outputs.
_λf := first ∘ NN_λ
# Instantaneous hazard at each observation time `t`, using the same formula
# as in @dynamics.
λ := @. _λf(t / 2000) + _base_λ
Death ~ @. TimeToEvent(λ, Λ)
end
end
# os_model2: overall-survival model whose hazard depends on time AND on five
# scaled covariates:
#     λ(t) = NN_λ(t / 2000, AGE/70, WT/100, ECOG/2, ALBB/200, sex) + base_λ
# Here sex is encoded as 0.0 for "Male" and 1.0 otherwise. As in the
# time-only model, the cumulative hazard Λ is integrated as an ODE state and
# `Death` is modelled with `TimeToEvent(λ, Λ)`.
os_model2 = @model begin
@param begin
# Network with 6 inputs (scaled time plus the 5 scaled covariates), two hidden
# layers of width 15, and a width-1 `softplus` output.
# NOTE(review): the L2 penalty here has input/output/bias all set to false.
# By those flag names, fewer parameter groups are regularised than in the
# time-only model (which sets them all to true), even though this network has
# more inputs. Confirm the difference is intentional.
NN_λ ∈ MLPDomain(
6, 15, 15, (1, softplus);
reg = L2(0.01; input = false, output = false, bias = false),
)
# Constant hazard floor with a strictly positive lower bound.
base_λ ∈ RealDomain(; lower=1e-8, init=1e-3)
end
@covariates begin
# Subject-level covariates used as network inputs.
AGE
WT
ECOG
ALBB
SEX
end
@pre begin
# `first ∘ NN_λ` unwraps the network's one-element output into a scalar.
λf = first ∘ NN_λ
_base_λ = base_λ
# Each covariate is divided by a fixed constant to bring network inputs
# towards O(1).
# NOTE(review): the divisors (70, 100, 2, 200) look like hand-picked typical
# values; confirm them against the data ranges.
# SEX is compared to the string "Male". Any other value (e.g. "male" or an
# integer code) is encoded as 1.0.
covs = (
AGE / 70,
WT / 100,
ECOG / 2,
ALBB / 200,
SEX == "Male" ? 0.0 : 1.0,
)
end
@dynamics begin
# `it` grows at unit rate, so (assuming a default initial value of 0) it
# tracks elapsed time.
it' = 1.0
# Cumulative hazard. The scaled time and the `covs` tuple are passed as
# separate arguments and presumably flattened into the network's 6 inputs.
Λ' = λf(it / 2000, covs) + _base_λ
end
@derived begin
# NOTE(review): `λf[1]` indexes `λf`. This suggests @pre variables appear in
# @derived as per-observation-time vectors. The function is the same at every
# time point, so the first entry is taken. `covs` is then broadcast alongside
# `t` by `@.`. Confirm this vector interpretation against the Pumas docs.
λf1 := λf[1]
# Instantaneous hazard at each observation time `t`.
λ := @. λf1(t / 2000, covs) + _base_λ
Death ~ @. TimeToEvent(λ, Λ)
end
end
# ---------------------------------------------------------------------------
# Fitting: both survival models are fitted on (at most) the first 150
# training subjects with identical settings, so their fits are comparable.
# MAP(NaivePooled()) fits population parameters only (neither model declares
# random effects). MAP is presumably used so the `reg` penalties on `NN_λ`
# enter the objective.
# NOTE(review): `sample_params` is called without an explicit RNG, so the
# starting values (and hence the fits) may differ between runs.
# ---------------------------------------------------------------------------
nsubj = min(length(os_tpop), 150)
_os_tpop = os_tpop[1:nsubj]

# Optimiser and ODE-solver settings shared by both fits (Rodas5P: a stiff
# Rosenbrock solver).
os_fit_kwargs = (;
    optim_options = (; iterations = 200, show_every = 1),
    diffeq_options = (; alg = Rodas5P()),
)

# Keep the original order: sample model-1 start values, fit model 1, then
# sample model-2 start values and fit model 2.
fpm1 = fit(
    os_model1, _os_tpop, sample_params(os_model1), MAP(NaivePooled());
    os_fit_kwargs...
)
fpm2 = fit(
    os_model2, _os_tpop, sample_params(os_model2), MAP(NaivePooled());
    os_fit_kwargs...
)
# ---------------------------------------------------------------------------
# Model comparison by total log-likelihood.
# Each value is bound to a name and reported with `@info`. Bare top-level
# expressions are not echoed when this file is run non-interactively
# (`julia 2_08_os.jl`), and `include` only returns the last one.
# Per-subject means are reported too. The training subset (`_os_tpop`) and the
# validation population (`os_vpop`) generally differ in size, so totals are
# only comparable between models on the same population, not across populations.
# ---------------------------------------------------------------------------

# Training log likelihood
ll_train1 = loglikelihood(os_model1, _os_tpop, coef(fpm1), NaivePooled())
ll_train2 = loglikelihood(os_model2, _os_tpop, coef(fpm2), NaivePooled())
# Test log likelihood
ll_test1 = loglikelihood(os_model1, os_vpop, coef(fpm1), NaivePooled())
ll_test2 = loglikelihood(os_model2, os_vpop, coef(fpm2), NaivePooled())

@info(
    "Training log likelihood",
    model1 = ll_train1,
    model2 = ll_train2,
    model1_per_subject = ll_train1 / length(_os_tpop),
    model2_per_subject = ll_train2 / length(_os_tpop)
)
@info(
    "Test log likelihood",
    model1 = ll_test1,
    model2 = ll_test2,
    model1_per_subject = ll_test1 / length(os_vpop),
    model2_per_subject = ll_test2 / length(os_vpop)
)