
#include "photos.h"

#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Confusion-matrix tallies, incremented by knn() once per classified photo.
int a = 0;  // predicted positive, actually positive (true positive)
int b = 0;  // predicted positive, actually negative (false positive)
int c = 0;  // predicted negative, actually positive (false negative)
int d = 0;  // predicted negative, actually negative (true negative)

// "Chronological" distance between two photos: the absolute difference of the
// two positions p1 and p2 (knn() passes indices into `photos`).
// NOTE(review): only meaningful if loadPhotos() orders photos by time — confirm.
// This measure is defined for every pair, so it always reports success.
bool chrono_dist(int p1, int p2, double &dist) {
	const double gap = static_cast<double>(p2 - p1);
	dist = (gap < 0.0) ? -gap : gap;
	return true;
}

// Looks up the precomputed PCA distance between face f1 of photo p1 and
// face f2 of photo p2.
//
// The distances are read from DATA_DIR "pca_distances/<p1>_<f1>.sfi". Each
// line is expected to look like "<p2>_<f2>.sfi<TAB><distance>". The first
// line whose key (the text before the first tab) equals "<p2>_<f2>.sfi"
// supplies the result. Lines without a tab, including blank lines, are
// skipped, and the whole file is scanned. Previously the scan stopped at
// the first blank line.
//
// This is a hard-failure helper. It prints a message to stderr and
// terminates the process with exit(-1) in three cases:
//   - the file cannot be opened;
//   - no line matches the key;
//   - the matching line's distance cannot be parsed.
double pca_dist(int p1, int f1, int p2, int f2) {
	char filename[1000];
	char target[64];
	snprintf(filename, sizeof(filename), DATA_DIR "pca_distances/%d_%d.sfi", p1, f1);
	snprintf(target, sizeof(target), "%d_%d.sfi", p2, f2);

	std::ifstream in(filename);
	if ( !in.good() ) {
		fprintf(stderr, "error opening %s\n", filename);
		exit(-1);
	}

	// Work on the std::string directly. The old fixed-size copy buffer
	// overflowed on lines longer than 999 characters.
	std::string line;
	while ( std::getline(in, line) ) {
		const std::string::size_type tab = line.find('\t');
		if ( tab == std::string::npos || line.compare(0, tab, target) != 0 ) {
			continue;
		}
		double value = 0.0;
		// An unparsable distance used to come back as 0.0, i.e. "closest
		// possible face". Treat it as the data error it is.
		if ( sscanf(line.c_str() + tab + 1, "%lf", &value) != 1 ) {
			fprintf(stderr, "error parsing face distance in %s: %s\n", filename, line.c_str());
			exit(-1);
		}
		return value;
	}

	fprintf(stderr, "error comparing face distances (%d.%d <-> %d.%d)\n", p1, f1, p2, f2);
	exit(-1);
}

// Face-based distance between photos p1 and p2 (indices into `photos`).
// The result is the smallest pca_dist() over every pairing of a face in p1
// with a face in p2.
// Returns false and leaves dist untouched when either photo has no faces, so
// knn() ignores the pair. Otherwise it stores the minimum in dist and
// returns true.
bool face_pca_dist(int p1, int p2, double &dist) {
	const int faces1 = photos[p1].getFaceCount();
	const int faces2 = photos[p2].getFaceCount();
	if ( faces1 <= 0 || faces2 <= 0 ) {
		return false;
	}
	// Each pca_dist() call opens and scans a file, so visit every face pair
	// exactly once. The first pair seeds the running minimum; previously,
	// pair (0,0) was looked up twice.
	double best = 0.0;
	for ( int i=0; i<faces1; i++ ) {
		for ( int j=0; j<faces2; j++ ) {
			const double candidate = pca_dist(
				photos[p1].getFaceID(),
				photos[p1].getFaceID(i),
				photos[p2].getFaceID(),
				photos[p2].getFaceID(j)
			);
			if ( (i == 0 && j == 0) || candidate < best ) {
				best = candidate;
			}
		}
	}
	dist = best;
	return true;
}

// Shared reader for the precomputed per-pair color distance files.
//
// Opens DATA_DIR "<subdir>/<id1>/<id1>_<id2>.txt" and ignores its first line.
// It then parses a double from the second line after dropping that line's
// first four characters.
// NOTE(review): the 4-character prefix looks like a label — confirm the file
// format.
//
// Any failure prints an error to stderr and terminates with exit(-1):
//   - the file cannot be opened;
//   - there is no second line, or it is shorter than 4 characters;
//   - the number cannot be parsed.
// This matches the other distance loaders in this file. Previously, a short
// or missing line made substr() throw std::out_of_range.
static bool read_color_distance_file(const char* subdir, int id1, int id2, double& dist) {
	char filename[1000];
	snprintf(filename, sizeof(filename), DATA_DIR "%s/%d/%d_%d.txt", subdir, id1, id1, id2);
	std::ifstream in(filename);
	if ( !in.good() ) {
		fprintf(stderr, "error opening %s\n", filename);
		exit(-1);
	}
	std::string line;
	// The first line is skipped; the value lives on the second line.
	if ( !std::getline(in, line) || !std::getline(in, line) || line.length() < 4 ) {
		fprintf(stderr, "error reading distance from %s\n", filename);
		exit(-1);
	}
	if ( sscanf(line.c_str() + 4, "%lf", &dist) != 1 ) {
		fprintf(stderr, "error parsing distance in %s: %s\n", filename, line.c_str());
		exit(-1);
	}
	return true;
}

// Color L2 distance between photos p1 and p2 (indices into `photos`).
// The value is read from DATA_DIR "color_dist/<id1>/<id1>_<id2>.txt", where
// idN = photos[pN].getID().
// Always returns true; exits the process on any read error.
bool color_l2_dist(int p1, int p2, double& dist) {
	return read_color_distance_file("color_dist", photos[p1].getID(), photos[p2].getID(), dist);
}

// Cumulative color L2 distance between photos p1 and p2 (indices into
// `photos`). The value is read from
// DATA_DIR "cumulative_color_distances/<id1>/<id1>_<id2>.txt".
// Always returns true; exits the process on any read error.
bool cumulative_color_l2_dist(int p1, int p2, double& dist) {
	return read_color_distance_file("cumulative_color_distances", photos[p1].getID(), photos[p2].getID(), dist);
}

// Online k-nearest-neighbour evaluation for one keyword.
//
// Photos are visited in index order. Photo i is classified using only the
// photos before it (indices 0..i-1). comp(i, j, dist) supplies the distance
// to each earlier photo j; pairs for which comp returns false are ignored.
// The (up to) n closest neighbours vote. Photo i is predicted to have
// `category` when at least half of them do, so ties count as positive.
// Photos with no usable neighbours are predicted negative.
//
// One line is printed to stdout per photo:
//     <pred><actual> (<id>-><vote fraction>) [<nbr id>:<distance>,...]
// where <pred> and <actual> are '+' or '-'.
//
// Each outcome is added to the global counters:
//   a = true positive, b = false positive,
//   c = false negative, d = true negative.
// They are printed as "a b c d" at the end. The counters are never reset,
// so repeated calls accumulate.
void knn(bool (*comp)(int p1, int p2, double &dist), std::string category, int n) {
	const int total = static_cast<int>(photos.size());
	for ( int i=0; i<total; i++ ) {
		// Max-heap on (-distance, index): top() is the nearest remaining
		// neighbour. Among equal distances, the larger index comes out first.
		std::priority_queue< std::pair<double,int> > neighbors;
		// Only photos that precede i are candidates (j indexes `photos`).
		for ( int j=0; j<i; j++ ) {
			double dist = 0.0;
			if ( comp(i, j, dist) ) {
				neighbors.push(std::pair<double,int>(-dist, j));
			}
		}

		// Take up to n nearest neighbours and count how many carry the keyword.
		std::vector< std::pair<double,int> > neighbor_list;
		int sum = 0;
		int count = 0;
		while ( count < n && !neighbors.empty() ) {
			const std::pair<double,int> neighbor = neighbors.top();
			neighbors.pop();
			if ( photos[neighbor.second].hasKeyword(category) ) {
				sum++;
			}
			count++;
			neighbor_list.push_back(neighbor);
		}

		const double fraction = count > 0 ? static_cast<double>(sum) / count : 0.0;
		const bool prediction = count > 0 && fraction >= 0.5;
		const bool actual = photos[i].hasKeyword(category);

		fputs(prediction ? "+" : "-", stdout);
		fputs(actual ? "+" : "-", stdout);
		printf(" (%d->%lf) [", photos[i].getID(), fraction);
		for ( size_t k=0; k<neighbor_list.size(); k++ ) {
			if ( k > 0 ) {
				fputs(",", stdout);
			}
			// Negate back to the original (non-negative) distance for display.
			printf("%d:%lf", photos[neighbor_list[k].second].getID(), -neighbor_list[k].first);
		}
		printf("]\n");

		if ( prediction && actual ) {
			a++;
		} else if ( prediction ) {
			b++;
		} else if ( actual ) {
			c++;
		} else {
			d++;
		}
	}
	printf("%d %d %d %d\n", a, b, c, d);
}

// Writes the command-line help text to stderr, then terminates the
// process with exit status -1. It never returns.
void usage() {
	static const char* const help_lines[] = {
		"usage: knn <comp-func> <category> <N>\n",
		"  where <comp-func> is:\n",
		"    chrono - chronological distance\n",
		"    colorl2 - color L2 distance\n",
		"    ccolorl2 - cumulative color L2 distance\n",
		"    facepca - face PCA distance\n"
	};
	const size_t line_count = sizeof(help_lines) / sizeof(help_lines[0]);
	for ( size_t k = 0; k < line_count; k++ ) {
		fputs(help_lines[k], stderr);
	}
	exit(-1);
}

// Entry point: knn <comp-func> <category> <N>
//
// Selects the distance function named by argv[1] and loads the photo
// collection. It then runs the k-NN evaluation for keyword argv[2] with
// N = argv[3] neighbours. A wrong argument count, an unknown comp-func, or
// an N that is not a positive integer prints usage and exits.
int main(int argc, char** argv) {
	if ( argc != 4 ) {
		usage();
	}

	bool (*comp)(int p1, int p2, double& dist) = NULL;
	if ( strcmp(argv[1],"chrono")==0 ) {
		comp = chrono_dist;
	} else if ( strcmp(argv[1],"colorl2")==0 ) {
		comp = color_l2_dist;
	} else if ( strcmp(argv[1],"ccolorl2")==0 ) {
		comp = cumulative_color_l2_dist;
	} else if ( strcmp(argv[1],"facepca")==0 ) {
		comp = face_pca_dist;
	} else {
		usage();
	}
	const std::string category = argv[2];

	// atoi() silently turned "abc", "-3" or "5x" into a neighbour count that
	// makes every prediction negative. Insist on a positive integer instead.
	char* end = NULL;
	const long parsed = strtol(argv[3], &end, 10);
	if ( end == argv[3] || *end != '\0' || parsed <= 0 || parsed > INT_MAX ) {
		fprintf(stderr, "invalid <N>: %s\n", argv[3]);
		usage();
	}
	const int n = static_cast<int>(parsed);

	loadPhotos();
	knn(comp, category, n);
	return 0;
}

