// Program listing for file gmm.h
// Source path: include/cpphots/clustering/gmm.h
#ifndef CPPHOTS_CLUSTERING_GMM_H
#define CPPHOTS_CLUSTERING_GMM_H
#include <memory>
#include <blaze/Math.h>
#include "../types.h"
#include "../interfaces/clustering.h"
#include "utils.h"
// Forward declarations of peregrine class templates. Within this header they
// are only named as std::shared_ptr template arguments.
template <class T> class Gmm_core;
template <class T> struct dataset;
namespace cpphots {
/// Blaze dynamic vector of TimeSurfaceScalarType elements, row-vector orientation.
using BlazeVector = blaze::DynamicVector<
    TimeSurfaceScalarType,
    blaze::rowVector>;

/// Blaze dynamic matrix of TimeSurfaceScalarType elements, row-major storage order.
using BlazeMatrix = blaze::DynamicMatrix<
    TimeSurfaceScalarType,
    blaze::rowMajor>;
/**
 * @brief GMM-based clusterer (GMM presumably stands for Gaussian mixture model).
 *
 * Derives from interfaces::Clonable<GMMClusterer, interfaces::Clusterer> and
 * from the ClustererHistogramMixin and ClustererOfflineMixin mixins. Every
 * public member function marked `override` replaces a virtual declared in one
 * of those bases; the definitions live outside this header.
 *
 * The algorithm object and the dataset are held through std::shared_ptr to
 * class templates that are only forward-declared at file scope.
 */
class GMMClusterer : public interfaces::Clonable<GMMClusterer, interfaces::Clusterer>,
                     public ClustererHistogramMixin,
                     public ClustererOfflineMixin {

public:

    /// Selects the GMM variant. The meaning of each value is defined by the
    /// implementation, which is not visible in this header.
    enum GMMType {
        S_GMM,
        U_S_GMM
    };

    /// Default constructor. Unless its out-of-line definition sets them, the
    /// scalar parameters keep the in-class defaults declared in the private section.
    GMMClusterer();

    /**
     * @brief Construct a clusterer with explicit parameters.
     *
     * @param type                GMM variant to use
     * @param clusters            number of clusters
     * @param truncated_clusters  NOTE(review): exact meaning not visible here - confirm against the implementation
     * @param clusters_considered NOTE(review): exact meaning not visible here - confirm against the implementation
     * @param eps                 NOTE(review): presumably a convergence tolerance - confirm
     * @param max_iterations      NOTE(review): presumably an iteration cap for fitting - confirm (defaults to 100)
     */
    GMMClusterer(GMMType type,
                 uint16_t clusters,
                 uint16_t truncated_clusters,
                 uint16_t clusters_considered,
                 TimeSurfaceScalarType eps,
                 unsigned int max_iterations = 100);

    /// Assign a time surface to a cluster and return the cluster id.
    uint16_t cluster(const TimeSurfaceType& surface) override;

    /// Number of clusters.
    uint16_t getNumClusters() const override;

    /// Add a centroid.
    void addCentroid(const TimeSurfaceType& centroid) override;

    /// Read-only access to the centroids as time surfaces.
    const std::vector<TimeSurfaceType>& getCentroids() const override;

    /// Remove the centroids.
    void clearCentroids() override;

    /// Whether centroids are available.
    bool hasCentroids() const override;

    /// Offline training on a collection of time surfaces.
    void train(const std::vector<TimeSurfaceType>& tss) override;

    /// Write the clusterer to an output stream.
    void toStream(std::ostream& out) const override;

    /// Read the clusterer from an input stream.
    void fromStream(std::istream& in) override;

private:

    // Every scalar member carries a default member initializer so that no
    // construction path can leave it with an indeterminate value. Declaration
    // order is unchanged, so constructor initializer lists keep their order.

    GMMType type = S_GMM;

    // Algorithm object from peregrine (incomplete type in this header).
    std::shared_ptr<Gmm_core<TimeSurfaceScalarType>> algo;

    uint16_t clusters = 0;
    uint16_t truncated_clusters = 0;
    uint16_t clusters_considered = 0;

    BlazeMatrix mean;

    // Dataset object from peregrine (incomplete type in this header).
    std::shared_ptr<dataset<TimeSurfaceScalarType>> set;

    size_t last_centroid = 0;
    bool learning = true;

    // Value-initialised: zero for an arithmetic TimeSurfaceScalarType.
    TimeSurfaceScalarType eps{};

    // Same value as the default argument of the parameterised constructor.
    unsigned int max_iterations = 100;

    // std::pair's default constructor value-initialises both elements.
    std::pair<uint16_t, uint16_t> ts_shape;

    // mutable: can be modified from const member functions.
    // NOTE(review): presumably a cache filled by getCentroids() - confirm.
    mutable std::vector<TimeSurfaceType> converted_centroids;

    /// Internal fitting step; defined out of line.
    void fit();

    /// Internal prediction on a Blaze row vector; returns a cluster id.
    /// NOTE(review): top_k semantics are not visible here - confirm.
    uint16_t predict(const BlazeVector& vec, int top_k = 1);

};
}
#endif