00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #include "KDTree.h"
00022 #include <limits>
00023 #include <list>
00024 #include <vector>
00025
00026 namespace KDTreeSpace
00027 {
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 template <class KE, class VTYPE>
00038 KDTree<KE, VTYPE>::KDTree(const ItemVector_t& iElemsList, int iDimensions) :
00039 _pivot(NULL), _leftKD(NULL), _rightKD(NULL), _dims(iDimensions)
00040 {
00041
00042 if (!iElemsList.size())
00043 {
00044 return;
00045 }
00046
00047
00048 ItemPtrVector_t aElemsPtrList;
00049 ItemVectorIt_t aElemsIt = iElemsList.begin();
00050 ItemVectorIt_t aElemsItEnd = iElemsList.end();
00051 for(; aElemsIt != aElemsItEnd; ++aElemsIt)
00052 {
00053 const KE& aTmp = *aElemsIt;
00054 aElemsPtrList.push_back(&aTmp);
00055 }
00056 init(aElemsPtrList);
00057 }
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 template <class KE, class VTYPE>
00068 KDTree<KE, VTYPE>::KDTree(const ItemPtrVector_t& iElemsPtrList, int iDimensions) :
00069 _pivot(NULL), _leftKD(NULL), _rightKD(NULL), _dims(iDimensions)
00070 {
00071 init(iElemsPtrList);
00072 }
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082 template <class KE, class VTYPE>
00083 KDTree<KE, VTYPE>::~KDTree()
00084 {
00085 delete(_leftKD);
00086 delete(_rightKD);
00087 }
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097 template <class KE, class VTYPE>
00098 void KDTree<KE, VTYPE>::init(const ItemPtrVector_t& iElemsPtrList)
00099 {
00100
00101 typename std::vector<const KE*>::const_iterator aPivotIt = choosePivot(iElemsPtrList);
00102 _pivot = *aPivotIt;
00103
00104
00105 ItemPtrVector_t aLeftElems, aRightElems;
00106
00107
00108 ItemPtrVectorIt_t aElemsIt = iElemsPtrList.begin();
00109 ItemPtrVectorIt_t aElemsItEnd = iElemsPtrList.end();
00110 VTYPE aPivotDimValue = _pivot->getVectorElem(_splitDim);
00111 for(; aElemsIt != aElemsItEnd; ++aElemsIt)
00112 {
00113
00114 if (aElemsIt == aPivotIt)
00115 {
00116 continue;
00117 }
00118
00119 if ((*aElemsIt)->getVectorElem(_splitDim) <= aPivotDimValue)
00120 {
00121 aLeftElems.push_back(*aElemsIt);
00122 }
00123 else
00124 {
00125 aRightElems.push_back(*aElemsIt);
00126 }
00127 }
00128
00129
00130
00131
00132
00133 if (aLeftElems.size())
00134 {
00135 _leftKD = new KDTree<KE, VTYPE> (aLeftElems, _dims);
00136 }
00137 else
00138 {
00139 _leftKD = NULL;
00140 }
00141
00142 if (aRightElems.size())
00143 {
00144 _rightKD = new KDTree<KE, VTYPE> (aRightElems, _dims);
00145 }
00146 else
00147 {
00148 _rightKD = NULL;
00149 }
00150
00151 }
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162 template <class KE, class VTYPE>
00163 typename std::vector<const KE*>::const_iterator KDTree<KE, VTYPE>::choosePivot(const ItemPtrVector_t& iElemsPtrList)
00164 {
00165
00166 std::vector<VTYPE> aMinVals(_dims, std::numeric_limits<VTYPE>::max());
00167 std::vector<VTYPE> aMaxVals(_dims, - std::numeric_limits<VTYPE>::max());
00168
00169 ItemPtrVectorIt_t aElemsIt = iElemsPtrList.begin();
00170 ItemPtrVectorIt_t aElemsEnd = iElemsPtrList.end();
00171 for(; aElemsIt != aElemsEnd; ++aElemsIt)
00172 {
00173 for (int aDim = 0; aDim < _dims; ++aDim)
00174 {
00175 VTYPE& aCurValue = (*aElemsIt)->getVectorElem(aDim);
00176 if (aCurValue < aMinVals[aDim])
00177 {
00178 aMinVals[aDim] = aCurValue;
00179 }
00180 if (aCurValue > aMaxVals[aDim])
00181 {
00182 aMaxVals[aDim] = aCurValue;
00183 }
00184 }
00185 }
00186
00187
00188 _splitDim = 0;
00189 VTYPE aLargestRangeValue = - std::numeric_limits<VTYPE>::max();
00190 for (int aDim = 0; aDim < _dims; ++aDim)
00191 {
00192 if ( (aMaxVals[aDim] - aMinVals[aDim]) > aLargestRangeValue )
00193 {
00194 _splitDim = aDim;
00195 aLargestRangeValue = aMaxVals[aDim] - aMinVals[aDim];
00196 }
00197 }
00198
00199
00200 VTYPE aMedian = aLargestRangeValue / 2 + aMinVals[_splitDim];
00201
00202
00203 typename std::vector<const KE*>::const_iterator aClosestElemIt;
00204 VTYPE aClosestDiff = std::numeric_limits<VTYPE>::max();
00205 for(aElemsIt = iElemsPtrList.begin(); aElemsIt != aElemsEnd; ++aElemsIt)
00206 {
00207 VTYPE aCurValue = fabs((*aElemsIt)->getVectorElem(_splitDim) - aMedian);
00208 if (aCurValue < aClosestDiff)
00209 {
00210 aClosestDiff = aCurValue;
00211 aClosestElemIt = aElemsIt;
00212 }
00213 }
00214
00215 return aClosestElemIt;
00216 }
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226 template <class KE, class VTYPE>
00227 double KDTree<KE, VTYPE>::calcSqDist(const KE* i1, const KE* i2)
00228 {
00229 double aDist = 0.0;
00230 for (int n = 0 ; n < _dims ; ++n)
00231 {
00232 double aDiff = i1->getVectorElem(n) - i2->getVectorElem(n);
00233 aDist += aDiff * aDiff;
00234 }
00235 return (aDist);
00236 }
00237
00238
00239 template <class KE, class VTYPE>
00240 std::set<BestMatch<KE>, std::greater<BestMatch<KE> > > KDTree<KE, VTYPE>::getNearestNeighboursBBF(const KE& iTarget, int iNbBestMatches, int iNbSearchSteps)
00241 {
00242
00243 HyperRectangle<KE, VTYPE> aHR(_dims);
00244
00245
00246 BestMatchLimitedSet_t aBestMatches(iNbBestMatches);
00247
00248
00249 std::list<QueueEntry<KE, VTYPE> > aSearchQueue;
00250
00251
00252 recurseNearestNeighboursBBF(iTarget, aHR, aBestMatches, aSearchQueue, iNbSearchSteps);
00253
00254 return aBestMatches.getSet();
00255 }
00256
00257 template <class KE, class VTYPE>
00258 void KDTree<KE, VTYPE>::recurseNearestNeighboursBBF(const KE& iTarget,
00259 HyperRectangle<KE, VTYPE> & iHR,
00260 BestMatchLimitedSet_t& ioBestMatches,
00261 QueueEntryList_t& ioSearchQueue,
00262 int& ioRemainingUnqueues)
00263 {
00264
00265 ioBestMatches.insert(BestMatch<KE>(_pivot, calcSqDist(&iTarget, _pivot)));
00266
00267
00268
00269
00270
00271
00272
00273 HyperRectangle<KE, VTYPE> aLeftHR(iHR);
00274 HyperRectangle<KE, VTYPE> aRightHR(iHR);
00275
00276
00277
00278 if (!iHR.split(aLeftHR, aRightHR, _splitDim, _pivot->getVectorElem(_splitDim)))
00279 {
00280 return;
00281 }
00282
00283
00284
00285
00286
00287
00288
00289
00290 KDTree<KE, VTYPE> * aNearKD = _leftKD;
00291 KDTree<KE, VTYPE> * aFarKD = _rightKD;
00292 HyperRectangle<KE, VTYPE> & aNearHR = aLeftHR;
00293 HyperRectangle<KE, VTYPE> & aFarHR = aRightHR;
00294
00295
00296
00297 if (iTarget.getVectorElem(_splitDim) > _pivot->getVectorElem(_splitDim))
00298 {
00299 aNearKD = _rightKD;
00300 aFarKD = _leftKD;
00301 aNearHR = aRightHR;
00302 aFarHR = aLeftHR;
00303 }
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313 if (aFarKD)
00314 {
00315 ioSearchQueue.push_back(QueueEntry<KE, VTYPE>(aFarHR, aFarKD, aFarHR.calcSqDistance(iTarget)));
00316 }
00317
00318
00319
00320
00321
00322
00323 if (aNearKD)
00324 {
00325
00326 aNearKD->recurseNearestNeighboursBBF(iTarget, aNearHR, ioBestMatches, ioSearchQueue, ioRemainingUnqueues);
00327 }
00328
00329 else if (ioRemainingUnqueues > 0 && ioBestMatches.size() && ioSearchQueue.size())
00330 {
00331
00332
00333
00334 ioRemainingUnqueues--;
00335
00336
00337 double aHyperSphereRadius = ioBestMatches.begin()->_distance;
00338
00339
00340 typename std::list<QueueEntry<KE, VTYPE> >::iterator aSQ, aSmallestIt;
00341 double aSmallestDist = std::numeric_limits<double>::max();
00342 for (aSQ = ioSearchQueue.begin(); aSQ!= ioSearchQueue.end(); ++aSQ)
00343 {
00344 if (aSQ->_dist < aSmallestDist)
00345 {
00346 aSmallestDist = aSQ->_dist;
00347 aSmallestIt = aSQ;
00348 }
00349 }
00350 QueueEntry<KE, VTYPE> aQueueElem = *(aSmallestIt);
00351 ioSearchQueue.erase(aSmallestIt);
00352
00353
00354
00355
00356 if (aQueueElem._HR.hasHyperSphereIntersect(iTarget, aHyperSphereRadius))
00357 {
00358 aQueueElem._kdTree->recurseNearestNeighboursBBF(iTarget, aQueueElem._HR, ioBestMatches, ioSearchQueue, ioRemainingUnqueues);
00359 }
00360 else
00361 {
00362
00363 }
00364 }
00365 }
00366
00367
00368
00369
00370
00371 template <class KE, class TYPE>
00372 HyperRectangle<KE, TYPE>::HyperRectangle() :
00373 _dim(0)
00374 {
00375
00376 }
00377
00378
00379
00380 template <class KE, class TYPE>
00381 HyperRectangle<KE, TYPE>::HyperRectangle(int iDim) :
00382 _leftTop(std::vector<TYPE>(iDim, -std::numeric_limits<TYPE>::max())),
00383 _rightBottom(std::vector<TYPE>(iDim, std::numeric_limits<TYPE>::max())),
00384 _dim(iDim)
00385 {
00386
00387 }
00388
00389
00390 template <class KE, class TYPE>
00391 HyperRectangle<KE, TYPE>::HyperRectangle(HyperRectangle& iOther) :
00392 _leftTop(std::vector<TYPE>(iOther._leftTop)),
00393 _rightBottom(std::vector<TYPE>(iOther._rightBottom)),
00394 _dim(iOther._dim)
00395 {
00396
00397 }
00398
00399
00400
00401
00402
00403 template <class KE, class TYPE>
00404 bool HyperRectangle<KE, TYPE>::split(HyperRectangle& oLeft, HyperRectangle& oRight, int iSplitDim, TYPE iSplitVal)
00405 {
00406
00407 if (_leftTop[iSplitDim] >= iSplitVal || _rightBottom[iSplitDim] < iSplitVal)
00408 {
00409 return false;
00410 }
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421 oLeft._rightBottom[iSplitDim] = iSplitVal;
00422 oRight._leftTop[iSplitDim] = iSplitVal;
00423
00424 return true;
00425 }
00426
00427 template <class KE, class TYPE>
00428 double HyperRectangle<KE, TYPE>::calcSqDistance (const KE& iTarget)
00429 {
00430 double aSqDistance = 0;
00431
00432
00433 for (int n = 0 ; n < _dim ; ++n)
00434 {
00435 TYPE aTargetVal = iTarget.getVectorElem(n);
00436 TYPE aHRMin = _leftTop[n];
00437 TYPE aHRMax = _rightBottom[n];
00438
00439 double aDimDist = aTargetVal;
00440 if (aTargetVal <= aHRMin)
00441 {
00442 aDimDist = aTargetVal - aHRMin;
00443 }
00444 else if (aTargetVal > aHRMin && aTargetVal < aHRMax)
00445 {
00446 aDimDist = 0;
00447 }
00448 else if (aTargetVal >= aHRMax)
00449 {
00450 aDimDist = aTargetVal - aHRMax;
00451 }
00452 aSqDistance += aDimDist * aDimDist;
00453 }
00454
00455 return aSqDistance;
00456 }
00457
00458 template <class KE, class TYPE>
00459 bool HyperRectangle<KE, TYPE>::hasHyperSphereIntersect(const KE& iTarget, double iSqDistance)
00460 {
00461
00462
00463
00464 double aDist = calcSqDistance(iTarget);
00465
00466
00467
00468
00469 return (aDist < iSqDistance);
00470 }
00471
00472 template <class KE, class TYPE>
00473 void HyperRectangle<KE, TYPE>::display()
00474 {
00475 for (int n = 0 ; n < _dim ; ++n)
00476 {
00477 if (_leftTop[n] > -std::numeric_limits<TYPE>::max()
00478 || _rightBottom[n] < std::numeric_limits<TYPE>::max())
00479 {
00480 std::cout << "dim[" << n << "] = {" << _leftTop[n] << " , " << _rightBottom[n] << "}" << std::endl;
00481 }
00482 }
00483 std::cout << std::endl;
00484 }
00485
00486 template <class KE, class TYPE>
00487 bool HyperRectangle<KE, TYPE>::isTargetIn (const KE& iTarget)
00488 {
00489 if (iTarget.getVectorSize() != _dim)
00490 {
00491 std::cout << "is target in dimension mismatch" << std::endl;
00492 }
00493
00494 for (int n = 0 ; n < _dim ; ++n)
00495 {
00496 TYPE aDimVal = iTarget.getVectorElem(n);
00497 if (aDimVal <= _leftTop[n] || aDimVal >= _rightBottom[n])
00498 {
00499 return (false);
00500 }
00501 }
00502
00503 return true;
00504 }
00505
00506 template <class KE, class VTYPE>
00507 bool operator < (const QueueEntry<KE, VTYPE> & iA, const QueueEntry<KE, VTYPE> & iB)
00508 {
00509 return (iA._dist < iB._dist);
00510 }
00511
00512 template <class KE>
00513 bool operator > (const BestMatch<KE> & iA, const BestMatch<KE> & iB)
00514 {
00515 return (iA._distance > iB._distance);
00516 }
00517
00518 }
00519
00520
00521