Program Listing for File gmm.h

Return to documentation for file (include/cpphots/clustering/gmm.h)

#ifndef CPPHOTS_CLUSTERING_GMM_H
#define CPPHOTS_CLUSTERING_GMM_H

#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <utility>
#include <vector>

#include <blaze/Math.h>

#include "../types.h"
#include "../interfaces/clustering.h"
#include "utils.h"


// forward declarations from peregrine
//
// These peregrine types are only named here, not defined: this header does not
// include any peregrine header. GMMClusterer holds them through std::shared_ptr
// members, and std::shared_ptr may be declared and default-constructed with an
// incomplete element type, so the full definitions are only needed where the
// objects are actually created/used (presumably the .cpp implementation).
// NOTE(review): the template signature and class-key (`class` vs `struct`)
// must stay identical to peregrine's own declarations — confirm when updating
// the peregrine dependency.
template <typename T>
class Gmm_core;

template <typename T>
struct dataset;


namespace cpphots {

/// Dense row vector of time-surface scalars (argument type of GMMClusterer::predict).
using BlazeVector = blaze::DynamicVector<TimeSurfaceScalarType, blaze::rowVector>;

/// Dense row-major matrix of time-surface scalars (type of GMMClusterer's `mean` member).
using BlazeMatrix = blaze::DynamicMatrix<TimeSurfaceScalarType, blaze::rowMajor>;

/**
 * @brief Clusterer backed by the Gaussian-mixture implementation (`Gmm_core`)
 *        of the peregrine library.
 *
 * Implements interfaces::Clusterer (made cloneable through interfaces::Clonable)
 * and mixes in ClustererHistogramMixin and ClustererOfflineMixin; train()
 * receives a whole batch of time surfaces at once.
 *
 * NOTE(review): `algo` and `set` are std::shared_ptr members and the class does
 * not declare a copy constructor, so the implicit copy shares the same peregrine
 * objects between original and copy instead of duplicating them. If
 * Clonable::clone() copy-constructs, clones are therefore not independent —
 * confirm this is intended before training/predicting on clones separately.
 */
class GMMClusterer : public interfaces::Clonable<GMMClusterer, interfaces::Clusterer>, public ClustererHistogramMixin, public ClustererOfflineMixin {

public:

    /// Variant of the peregrine GMM algorithm to run.
    /// Kept as an unscoped enum so existing `GMMClusterer::S_GMM` spellings
    /// (and any implicit integer conversions) keep compiling.
    /// TODO: document the difference between the two variants.
    enum GMMType {
        S_GMM,
        U_S_GMM
    };

    /// Default constructor (defined out of line). Members it does not set fall
    /// back to the in-class initializers below, so none is left indeterminate.
    GMMClusterer();

    /**
     * @brief Construct a configured clusterer.
     * @param type                GMM variant to use
     * @param clusters            number of clusters
     * @param truncated_clusters  truncation parameter passed to the GMM backend — TODO confirm semantics
     * @param clusters_considered clusters considered by the backend — TODO confirm semantics
     * @param eps                 presumably the training convergence threshold — confirm
     * @param max_iterations      presumably the cap on training iterations (default 100)
     */
    GMMClusterer(GMMType type, uint16_t clusters, uint16_t truncated_clusters, uint16_t clusters_considered, TimeSurfaceScalarType eps, unsigned int max_iterations=100);

    // --- interfaces::Clusterer overrides (see that interface for the full contract) ---

    /// Presumably returns the index of the cluster assigned to `surface`.
    uint16_t cluster(const TimeSurfaceType& surface) override;

    /// Number of clusters of this model.
    uint16_t getNumClusters() const override;

    /// Add one centroid to the model.
    void addCentroid(const TimeSurfaceType& centroid) override;

    /// Centroids as time surfaces. Presumably refers to the `mutable`
    /// `converted_centroids` cache, so the reference may change on later calls.
    const std::vector<TimeSurfaceType>& getCentroids() const override;

    /// Remove all centroids.
    void clearCentroids() override;

    /// Whether any centroid is present.
    bool hasCentroids() const override;

    /// Train the model on the batch of time surfaces `tss`.
    void train(const std::vector<TimeSurfaceType>& tss) override;

    /// Serialize the clusterer state to `out`.
    void toStream(std::ostream& out) const override;

    /// Restore the clusterer state from `in`.
    void fromStream(std::istream& in) override;

private:
    GMMType type = S_GMM;

    // peregrine GMM engine; shared_ptr permits the forward-declared (incomplete) type here.
    std::shared_ptr<Gmm_core<TimeSurfaceScalarType>> algo;

    // Declared one per line, in the original order (keeps layout/init order).
    uint16_t clusters = 0;
    uint16_t truncated_clusters = 0;
    uint16_t clusters_considered = 0;

    BlazeMatrix mean;  // presumably the cluster means — confirm

    // peregrine dataset handle; shared_ptr for the same incomplete-type reason as `algo`.
    std::shared_ptr<dataset<TimeSurfaceScalarType>> set;

    size_t last_centroid = 0;
    bool learning = true;
    TimeSurfaceScalarType eps{};
    unsigned int max_iterations = 100;  // mirrors the constructor's default argument
    std::pair<uint16_t, uint16_t> ts_shape;  // value-initialized to {0, 0}; presumably the time-surface dimensions

    // Cache filled from a const context (hence `mutable`), presumably by getCentroids().
    // NOTE(review): writes through a const method are unsynchronized; concurrent
    // const calls would race if they refresh this cache.
    mutable std::vector<TimeSurfaceType> converted_centroids;

    /// Fit the underlying GMM (internal training step).
    void fit();

    /// Inference helper on a single vector; `top_k` defaults to 1.
    uint16_t predict(const BlazeVector& vec, int top_k = 1);

};

}  // namespace cpphots

#endif