You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
8.1 KiB
234 lines
8.1 KiB
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package com.google.ux.material.libmonet.quantize;
|
|
|
|
import static java.lang.Math.min;
|
|
|
|
import java.util.Arrays;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.Map;
|
|
import java.util.Random;
|
|
|
|
/**
|
|
* An image quantizer that improves on the speed of a standard K-Means algorithm by implementing
|
|
* several optimizations, including deduping identical pixels and a triangle inequality rule that
|
|
* reduces the number of comparisons needed to identify which cluster a point should be moved to.
|
|
*
|
|
* <p>Wsmeans stands for Weighted Square Means.
|
|
*
|
|
* <p>This algorithm was designed by M. Emre Celebi, and was found in their 2011 paper, Improving
|
|
* the Performance of K-Means for Color Quantization. https://arxiv.org/abs/1101.0395
|
|
*/
|
|
public final class QuantizerWsmeans {
|
|
private QuantizerWsmeans() {}
|
|
|
|
private static final class Distance implements Comparable<Distance> {
|
|
int index;
|
|
double distance;
|
|
|
|
Distance() {
|
|
this.index = -1;
|
|
this.distance = -1;
|
|
}
|
|
|
|
@Override
|
|
public int compareTo(Distance other) {
|
|
return ((Double) this.distance).compareTo(other.distance);
|
|
}
|
|
}
|
|
|
|
private static final int MAX_ITERATIONS = 10;
|
|
private static final double MIN_MOVEMENT_DISTANCE = 3.0;
|
|
|
|
/**
|
|
* Reduce the number of colors needed to represented the input, minimizing the difference between
|
|
* the original image and the recolored image.
|
|
*
|
|
* @param inputPixels Colors in ARGB format.
|
|
* @param startingClusters Defines the initial state of the quantizer. Passing an empty array is
|
|
* fine, the implementation will create its own initial state that leads to reproducible
|
|
* results for the same inputs. Passing an array that is the result of Wu quantization leads
|
|
* to higher quality results.
|
|
* @param maxColors The number of colors to divide the image into. A lower number of colors may be
|
|
* returned.
|
|
* @return Map with keys of colors in ARGB format, values of how many of the input pixels belong
|
|
* to the color.
|
|
*/
|
|
public static Map<Integer, Integer> quantize(
|
|
int[] inputPixels, int[] startingClusters, int maxColors) {
|
|
// Uses a seeded random number generator to ensure consistent results.
|
|
Random random = new Random(0x42688);
|
|
|
|
Map<Integer, Integer> pixelToCount = new LinkedHashMap<>();
|
|
double[][] points = new double[inputPixels.length][];
|
|
int[] pixels = new int[inputPixels.length];
|
|
PointProvider pointProvider = new PointProviderLab();
|
|
|
|
int pointCount = 0;
|
|
for (int i = 0; i < inputPixels.length; i++) {
|
|
int inputPixel = inputPixels[i];
|
|
Integer pixelCount = pixelToCount.get(inputPixel);
|
|
if (pixelCount == null) {
|
|
points[pointCount] = pointProvider.fromInt(inputPixel);
|
|
pixels[pointCount] = inputPixel;
|
|
pointCount++;
|
|
|
|
pixelToCount.put(inputPixel, 1);
|
|
} else {
|
|
pixelToCount.put(inputPixel, pixelCount + 1);
|
|
}
|
|
}
|
|
|
|
int[] counts = new int[pointCount];
|
|
for (int i = 0; i < pointCount; i++) {
|
|
int pixel = pixels[i];
|
|
int count = pixelToCount.get(pixel);
|
|
counts[i] = count;
|
|
}
|
|
|
|
int clusterCount = min(maxColors, pointCount);
|
|
if (startingClusters.length != 0) {
|
|
clusterCount = min(clusterCount, startingClusters.length);
|
|
}
|
|
|
|
double[][] clusters = new double[clusterCount][];
|
|
int clustersCreated = 0;
|
|
for (int i = 0; i < startingClusters.length; i++) {
|
|
clusters[i] = pointProvider.fromInt(startingClusters[i]);
|
|
clustersCreated++;
|
|
}
|
|
|
|
int additionalClustersNeeded = clusterCount - clustersCreated;
|
|
if (additionalClustersNeeded > 0) {
|
|
for (int i = 0; i < additionalClustersNeeded; i++) {}
|
|
}
|
|
|
|
int[] clusterIndices = new int[pointCount];
|
|
for (int i = 0; i < pointCount; i++) {
|
|
clusterIndices[i] = random.nextInt(clusterCount);
|
|
}
|
|
|
|
int[][] indexMatrix = new int[clusterCount][];
|
|
for (int i = 0; i < clusterCount; i++) {
|
|
indexMatrix[i] = new int[clusterCount];
|
|
}
|
|
|
|
Distance[][] distanceToIndexMatrix = new Distance[clusterCount][];
|
|
for (int i = 0; i < clusterCount; i++) {
|
|
distanceToIndexMatrix[i] = new Distance[clusterCount];
|
|
for (int j = 0; j < clusterCount; j++) {
|
|
distanceToIndexMatrix[i][j] = new Distance();
|
|
}
|
|
}
|
|
|
|
int[] pixelCountSums = new int[clusterCount];
|
|
for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
|
|
for (int i = 0; i < clusterCount; i++) {
|
|
for (int j = i + 1; j < clusterCount; j++) {
|
|
double distance = pointProvider.distance(clusters[i], clusters[j]);
|
|
distanceToIndexMatrix[j][i].distance = distance;
|
|
distanceToIndexMatrix[j][i].index = i;
|
|
distanceToIndexMatrix[i][j].distance = distance;
|
|
distanceToIndexMatrix[i][j].index = j;
|
|
}
|
|
Arrays.sort(distanceToIndexMatrix[i]);
|
|
for (int j = 0; j < clusterCount; j++) {
|
|
indexMatrix[i][j] = distanceToIndexMatrix[i][j].index;
|
|
}
|
|
}
|
|
|
|
int pointsMoved = 0;
|
|
for (int i = 0; i < pointCount; i++) {
|
|
double[] point = points[i];
|
|
int previousClusterIndex = clusterIndices[i];
|
|
double[] previousCluster = clusters[previousClusterIndex];
|
|
double previousDistance = pointProvider.distance(point, previousCluster);
|
|
|
|
double minimumDistance = previousDistance;
|
|
int newClusterIndex = -1;
|
|
for (int j = 0; j < clusterCount; j++) {
|
|
if (distanceToIndexMatrix[previousClusterIndex][j].distance >= 4 * previousDistance) {
|
|
continue;
|
|
}
|
|
double distance = pointProvider.distance(point, clusters[j]);
|
|
if (distance < minimumDistance) {
|
|
minimumDistance = distance;
|
|
newClusterIndex = j;
|
|
}
|
|
}
|
|
if (newClusterIndex != -1) {
|
|
double distanceChange =
|
|
Math.abs(Math.sqrt(minimumDistance) - Math.sqrt(previousDistance));
|
|
if (distanceChange > MIN_MOVEMENT_DISTANCE) {
|
|
pointsMoved++;
|
|
clusterIndices[i] = newClusterIndex;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (pointsMoved == 0 && iteration != 0) {
|
|
break;
|
|
}
|
|
|
|
double[] componentASums = new double[clusterCount];
|
|
double[] componentBSums = new double[clusterCount];
|
|
double[] componentCSums = new double[clusterCount];
|
|
Arrays.fill(pixelCountSums, 0);
|
|
for (int i = 0; i < pointCount; i++) {
|
|
int clusterIndex = clusterIndices[i];
|
|
double[] point = points[i];
|
|
int count = counts[i];
|
|
pixelCountSums[clusterIndex] += count;
|
|
componentASums[clusterIndex] += (point[0] * count);
|
|
componentBSums[clusterIndex] += (point[1] * count);
|
|
componentCSums[clusterIndex] += (point[2] * count);
|
|
}
|
|
|
|
for (int i = 0; i < clusterCount; i++) {
|
|
int count = pixelCountSums[i];
|
|
if (count == 0) {
|
|
clusters[i] = new double[] {0., 0., 0.};
|
|
continue;
|
|
}
|
|
double a = componentASums[i] / count;
|
|
double b = componentBSums[i] / count;
|
|
double c = componentCSums[i] / count;
|
|
clusters[i][0] = a;
|
|
clusters[i][1] = b;
|
|
clusters[i][2] = c;
|
|
}
|
|
}
|
|
|
|
Map<Integer, Integer> argbToPopulation = new LinkedHashMap<>();
|
|
for (int i = 0; i < clusterCount; i++) {
|
|
int count = pixelCountSums[i];
|
|
if (count == 0) {
|
|
continue;
|
|
}
|
|
|
|
int possibleNewCluster = pointProvider.toInt(clusters[i]);
|
|
if (argbToPopulation.containsKey(possibleNewCluster)) {
|
|
continue;
|
|
}
|
|
|
|
argbToPopulation.put(possibleNewCluster, count);
|
|
}
|
|
|
|
return argbToPopulation;
|
|
}
|
|
}
|