-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbadtest.py
executable file
·87 lines (59 loc) · 1.85 KB
/
badtest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#! /usr/bin/python3
import numpy
import random
import badrbm
# Training set: the five one-hot 5x1 column vectors, followed by an all-ones
# vector.
training_set = []
for i in range(5):
    one_hot = numpy.zeros((5, 1))
    one_hot[i] = 1.0
    training_set.append(one_hot)
# Append an all-on vector to the training set.
training_set.append(numpy.ones((5, 1)))

# NOTE(review): the meaning of badrbm.rbm's positional arguments is not
# visible here; presumably (visible units, hidden units, learning rate) plus
# p=... -- confirm against badrbm.
r = badrbm.rbm(5, 6, 0.1, p=0.1)

# Train with 10**6 single-vector updates, each vector drawn uniformly at
# random from the training set.
for _ in range(10 ** 6):
    r.apply_update(random.choice(training_set))
print("trained")
# Sampling seems to get stuck repeating a single sample, which somewhat makes sense.
# So run multiple trials and check that, whenever it gets stuck, it is stuck on a training sample.
def demented_energy_guess(rbm, v, samples=20):
    """Estimate the energy of visible vector *v* by averaging over hidden draws.

    Calls ``rbm.get_h(v)`` *samples* times, scores each resulting hidden
    vector with ``rbm.get_energy(v, h)``, and returns the arithmetic mean of
    those energies. The averaging presumably exists because ``get_h`` samples
    the hidden units stochastically -- confirm against badrbm.

    Args:
        rbm: object providing ``get_h(v)`` and ``get_energy(v, h)``.
        v: visible vector, passed unchanged to both methods.
        samples: number of hidden draws to average (an int >= 1).

    Returns:
        The mean of the collected energies.

    Raises:
        ValueError: if *samples* is less than 1, because there would be
            nothing to average.
    """
    if samples < 1:
        raise ValueError(f"samples must be at least 1, got {samples}")
    # get_h is called before get_energy on every draw, as in the original loop.
    energies = [rbm.get_energy(v, rbm.get_h(v)) for _ in range(samples)]
    return sum(energies) / len(energies)
def get_random_sample(prob=0.5, size=5):
    """Return a random binary column vector.

    Each entry is independently 1.0 with probability *prob* and 0.0
    otherwise. The draws come from ``numpy.random.rand``, i.e. NumPy's
    global RNG.

    Args:
        prob: probability that any given entry is 1.0.
        size: number of entries. Defaults to 5 to match the 5x1 vectors in
            this script's training set.

    Returns:
        A float ``numpy.ndarray`` of shape ``(size, 1)`` containing only
        0.0 and 1.0.
    """
    # A uniform [0, 1) draw is below prob with probability prob, so the
    # comparison yields a Bernoulli(prob) mask; cast it to 0.0/1.0 floats.
    return (numpy.random.rand(size, 1) < prob).astype(float)
with open("samples.dat", "w", encoding="utf-8") as ofile:
    # 40 independent trials of r.get_samples(10); each returned sample is
    # written transposed (as a row) on its own line. Repeating the trials
    # shows which vectors the sampler settles on.
    # NOTE(review): presumably get_samples(10) returns 10 successive samples
    # from one chain -- confirm against badrbm.
    for trial in range(40):
        ofile.write(f"trial {trial}\n")
        for sample in r.get_samples(10):
            ofile.write(str(sample.T) + "\n")
        ofile.write("\n\n")

    # Estimated energies: first every training vector, then 20 random binary
    # vectors (each unit on with probability 0.5) for comparison.
    ofile.write("energies:\n")
    for s in training_set:
        e = demented_energy_guess(r, s)
        ofile.write(f"{s.T} {e}\n")
    for _ in range(20):
        s = get_random_sample()
        e = demented_energy_guess(r, s)
        ofile.write(f"{s.T} {e}\n")

    # Dump the RBM's W, a and b attributes (presumably its weights and the
    # two bias vectors).
    ofile.write(f"\n\nW:\n{r.W}\n\na:\n{r.a}\n\nb:\n{r.b}\n")