001package com.astrolabsoftware.FinkBrowser.HBaser.Clusteriser;
002
003import com.fasterxml.jackson.databind.ObjectMapper;
004import org.apache.commons.math3.linear.*;
005
006import java.util.List;
007import java.util.Arrays;
008import java.io.File;
009import java.io.IOException;
010
011// Log4J
012import org.apache.logging.log4j.Logger;
013import org.apache.logging.log4j.LogManager;
014
015/** <code>ClusterFinder</code> identifies HBase rows with
016  * clusters defined by previous clustering algorithm, read from
017  * <tt>JSON</tt> model files.
018  * @opt attributes
019  * @opt operations
020  * @opt types
021  * @opt visibility
022  * @author <a href="mailto:Julius.Hrivnac@cern.ch">J.Hrivnac</a> */
023public class ClusterFinder {
024  
025  public static void main(String[] args) throws IOException {
026    ClusterFinder finder = new ClusterFinder("/tmp/scaler_params.json",
027                                             "/tmp/pca_params.json",
028                                             "/tmp/cluster_centers.json");
029    double[] newData = {1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
030                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
031                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
032                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
033                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,};  // Example input
034    int cluster = finder.transformAndPredict(newData);
035    log.info("Assigned cluster: " + cluster);
036    }
037
038  public ClusterFinder(String scalerFile,
039                       String pcaFile,
040                       String clustersFile) throws IOException {
041    loadScalerParams(scalerFile);
042    loadPCAParams(pcaFile);
043    loadClusterCenters(clustersFile);
044    }
045  
046  private void loadScalerParams(String filePath) throws IOException {
047    ObjectMapper objectMapper = new ObjectMapper();
048    ScalerParams params = objectMapper.readValue(new File(filePath), ScalerParams.class);
049    _mean = params.mean;
050    _std = params.std;
051    log.debug("Scaler: " + _mean.length);
052    }
053  
054  private void loadPCAParams(String filePath) throws IOException {
055    ObjectMapper objectMapper = new ObjectMapper();
056    PCAParams params = objectMapper.readValue(new File(filePath), PCAParams.class);
057    _pcaComponents = new Array2DRowRealMatrix(params.components);
058    _explainedVariance = params.explained_variance;
059    log.debug("PCA Components: " + _pcaComponents.getColumnDimension() + " * " + _pcaComponents.getRowDimension());
060    }
061    
062  private void loadClusterCenters(String filePath) throws IOException {
063    ObjectMapper objectMapper = new ObjectMapper();
064    _clusterCenters = new Array2DRowRealMatrix(objectMapper.readValue(new File(filePath), double[][].class));
065    log.debug("Cluster Centers: " + _clusterCenters.getColumnDimension() + " * " + _clusterCenters.getRowDimension());
066    }    
067    
068  
069  private double[] standardize(double[] input) {
070    double[] standardized = new double[input.length];
071    for (int i = 0; i < input.length; i++) {
072      if (_std[i] == 0) {
073        standardized[i] = 0;
074        }
075      else {
076        standardized[i] = (input[i] - _mean[i]) / _std[i];
077        }
078      }
079    log.debug("Standardized: " + standardized.length);
080    return standardized;
081    }
082  
083  private double[] applyPCA(double[] standardizedInput) {
084    RealVector inputVector = new ArrayRealVector(standardizedInput);
085    RealVector transformed = _pcaComponents.transpose().operate(inputVector);
086    log.debug("PCA Transformed: " + transformed.getDimension());
087    return transformed.toArray();
088    }
089  
090  /** Find the closest cluster from the transformed data.
091    * @param  transformedData The transformed input data.
092    * @return                 The (number of) the closest cluster.
093    *                         <tt>-1</tt> if it cannot be found with sufficient resolution. */
094  private int findClosestCluster(double[] transformedData) {
095    RealVector transformedVector = new ArrayRealVector(transformedData);
096    double minDistance  = Double.MAX_VALUE;
097    double minDistance2 = Double.MAX_VALUE;
098    int closestCluster = -1;
099    RealVector clusterCenter;
100    double distance;
101    for (int i = 0; i < _clusterCenters.getRowDimension(); i++) {
102      clusterCenter = _clusterCenters.getRowVector(i);
103      distance = transformedVector.getDistance(clusterCenter);
104      if (distance < minDistance2) {
105        if (distance < minDistance) {
106          minDistance2  = minDistance;
107          minDistance   = distance;
108          closestCluster = i;
109          }
110        else {
111          minDistance2 = distance;
112          }
113        }
114      }
115    if (minDistance < _separation * minDistance2) {
116      return closestCluster;
117      }
118    return -1;
119    }
120  
121  /** Transform provided data array and find the closest cluster.
122    * @param  inputData The original input data.
123    * @return           The (number of) the closest cluster.
124    *                   <tt>-1</tt> if it cannot be found with sufficient resolution. */
125  public int transformAndPredict(double[] inputData) {
126    double[] standardized = standardize(inputData);
127    double[] pcaTransformed = applyPCA(standardized);
128    return findClosestCluster(pcaTransformed);
129    }
130    
131  /** Set the minimal separation quotient.
132    * @param separation The minimal separation quotient.
133    *                   The ration between distance to closest and second closest
134    *                   cluster should be smaller than <tt>separation</tt>,
135    *                   otherwise cluster is not considered reliable.
136    *                   <tt>1</tt> gives no restriction. The default is <tt>0.5</tt>. */
137  private static void setSeparation(double separation) {
138    _separation = separation;
139    }
140  
141  private static double _separation = 0.5;  
142    
143  private double[] _mean;
144  
145  private double[] _std;
146  
147  private RealMatrix _pcaComponents;
148  
149  private double[] _explainedVariance;
150  
151  private RealMatrix _clusterCenters;
152    
153  /** Logging . */
154  private static Logger log = LogManager.getLogger(ClusterFinder.class);
155  
156  }