Skip to content

Commit

Permalink
added probguess
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Weber committed Feb 11, 2009
1 parent 2d1bb20 commit 9a52026
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions numpredict.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,26 @@ def costf(scale):
return costf


def probguess(data, vec1, low, high, k=5, weightfun=gaussianweight):
"""Returns the probability that the result for input vec1 is in the
interval [low, hight], based on the trainingdata data."""
dlist = getdistances(data, vec1)
nweight = 0.0 # weight of neighbors in interval
tweight = 0.0 # weight of all neighbors ("total weight")

for i in range(k):
dist = dlist[i][0]
idx = dlist[i][1]
weight = weightfun(dist)
v = data[idx]['result']

if low <= v <= high:
nweight += weight
tweight += weight
if tweight == 0.0: return 0.0
return nweight / tweight


if __name__ == '__main__':
s = wineset1(50)

Expand All @@ -157,6 +177,7 @@ def costf(scale):
print crossvalidate(lambda d, v: knnestimate(d, v, k=7), s)
print crossvalidate(lambda d, v: weightedknn(d, v, k=5), s)

# Use optimization to automatically rescale different dimensions
print
print 'set 2, not-to-scale parameters (XXX buggy, broken, incomplete)'
s = wineset2(50)
Expand All @@ -170,6 +191,15 @@ def costf(scale):
#print optimization.annealingoptimize([(0, 20)] * 4,
#createcostfunction(knnestimate, s), step=2)

# This shows that tracking distributions is worthwile
print
print 'set 3, uneven distribution'
s = wineset3()
print probguess(s, [99, 20], 20, 120)
print probguess(s, [99, 20], 120, 1000)
print probguess(s, [99, 20], 40, 80)
print probguess(s, [99, 20], 80, 120)
print 'real price:', wineprice(99.0, 20.0)
print 'estimated price:', weightedknn(s, [99.0, 20.0])
print 'crossvalidation error:', crossvalidate(weightedknn, s)

0 comments on commit 9a52026

Please sign in to comment.