for (int r = 0; r < n_rows; ++r) {
	for (int c = 0; c < n_cols; ++c) {
	unsigned char temp = 0;
	file.read((char*)&temp, sizeof(temp));
	data_dst[addr + width_image_input_CNN * (r + y_padding) + c + x_padding] = (temp / 255.0) * (scale_max - scale_min) + scale_min;
	}
	}
	}
	}
	// Reads an MNIST label file (IDX1 format: big-endian magic number, big-endian
	// item count, then one unsigned byte per label) into a pre-allocated buffer
	// laid out as rows of num_map_output_CNN doubles, one row per image.
	//
	// For image i with label k only data_dst[i * num_map_output_CNN + k] is written
	// (set to scale_max); every other entry of the row is left untouched, so the
	// caller must pre-fill the buffer with the "inactive" target value.
	// NOTE(review): presumably that fill value is -0.8 (mirroring scale_max) — confirm at the allocation site.
	//
	// @param filename  path to the IDX1 label file
	// @param data_dst  output buffer with room for num_image * num_map_output_CNN doubles
	// @param num_image number of labels the caller expects and has room for
	static void readMnistLabels(std::string filename, double* data_dst, int num_image)
	{
		const double scale_max = 0.8;
		// IDX magic 0x00000801: unsigned-byte data, 1 dimension (label files).
		const int label_file_magic = 2049;

		std::ifstream file(filename, std::ios::binary);
		assert(file.is_open());

		int magic_number = 0;
		int number_of_images = 0;
		file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
		magic_number = reverseInt(magic_number);
		assert(magic_number == label_file_magic);
		file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
		number_of_images = reverseInt(number_of_images);
		assert(number_of_images == num_image);

		// The asserts above vanish under NDEBUG; bound the loop by both the header
		// count and the caller's capacity so a mismatched file cannot write past
		// the end of data_dst.
		for (int i = 0; i < number_of_images && i < num_image; ++i) {
			unsigned char label = 0;
			if (!file.read(reinterpret_cast<char*>(&label), sizeof(label))) {
				assert(!"truncated MNIST label file");
				break;
			}
			// A label outside [0, num_map_output_CNN) would land in the next row
			// (or beyond the buffer); skip it instead of corrupting memory.
			if (label >= num_map_output_CNN) {
				assert(!"MNIST label out of range");
				continue;
			}
			data_dst[i * num_map_output_CNN + label] = scale_max;
		}
	}
	// Loads the MNIST training and test sets into this network's buffers:
	// images into data_input_train / data_input_test, label targets into
	// data_output_train / data_output_test. All four buffers must already be
	// allocated (checked by the assert below).
	// NOTE(review): dataset locations are hard-coded absolute Windows paths.
	// This function does no error checking of its own and always returns true.
	bool CNN::getSrcData()
	{
		assert(data_input_train && data_output_train && data_input_test && data_output_test);

		std::string train_images_path = "E:/GitCode/NN_Test/data/train-images.idx3-ubyte";
		std::string train_labels_path = "E:/GitCode/NN_Test/data/train-labels.idx1-ubyte";
		std::string test_images_path = "E:/GitCode/NN_Test/data/t10k-images.idx3-ubyte";
		std::string test_labels_path = "E:/GitCode/NN_Test/data/t10k-labels.idx1-ubyte";

		// Training set.
		readMnistImages(train_images_path, data_input_train, num_patterns_train_CNN);
		readMnistLabels(train_labels_path, data_output_train, num_patterns_train_CNN);

		// Test set.
		readMnistImages(test_images_path, data_input_test, num_patterns_test_CNN);
		readMnistLabels(test_labels_path, data_output_test, num_patterns_test_CNN);

		return true;
	}
	// Trains the network on the loaded MNIST training set.
	//
	// 1. Clears and rebuilds the connection lookup tables for the S2/S4 layers
	//    (calc_out2wi / calc_out2bias) and the C1/C3 layers (calc_in2wo /
	//    calc_weight2io / calc_bias2out).
	// 2. Runs up to num_epochs_CNN epochs. Each training pattern gets a full
	//    forward pass, a full backward pass and UpdateWeights().
	// 3. After every epoch, evaluates test() and stops early once the accuracy
	//    exceeds accuracy_rate_CNN.
	// 4. Whether it stopped early or ran all epochs, saves the model exactly once.
	//
	// Always returns true.
	bool CNN::train()
	{
		out2wi_S2.clear();
		out2bias_S2.clear();
		out2wi_S4.clear();
		out2bias_S4.clear();
		in2wo_C3.clear();
		weight2io_C3.clear();
		bias2out_C3.clear();
		in2wo_C1.clear();
		weight2io_C1.clear();
		bias2out_C1.clear();

		calc_out2wi(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_S2_CNN, out2wi_S2);
		calc_out2bias(width_image_S2_CNN, height_image_S2_CNN, num_map_S2_CNN, out2bias_S2);
		calc_out2wi(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_S4_CNN, out2wi_S4);
		calc_out2bias(width_image_S4_CNN, height_image_S4_CNN, num_map_S4_CNN, out2bias_S4);
		calc_in2wo(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, in2wo_C3);
		calc_weight2io(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, weight2io_C3);
		calc_bias2out(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, bias2out_C3);
		// NOTE(review): the C3 tables pass the following pooling layer's depth
		// (num_map_S4_CNN), but the C1 tables pass num_map_C3_CNN rather than
		// num_map_S2_CNN. Confirm this asymmetry is intended.
		calc_in2wo(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, in2wo_C1);
		calc_weight2io(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, weight2io_C1);
		calc_bias2out(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, bias2out_C1);

		for (int epoch = 0; epoch < num_epochs_CNN; ++epoch) {
			// Flush so the epoch number is visible while the (long) pass runs.
			std::cout << "epoch: " << epoch + 1 << std::flush;

			for (int i = 0; i < num_patterns_train_CNN; ++i) {
				// NOTE(review): readMnistLabels lays rows out with stride
				// num_map_output_CNN; this assumes num_neuron_output_CNN equals it.
				data_single_image = data_input_train + i * num_neuron_input_CNN;
				data_single_label = data_output_train + i * num_neuron_output_CNN;

				Forward_C1();
				Forward_S2();
				Forward_C3();
				Forward_S4();
				Forward_C5();
				Forward_output();

				Backward_output();
				Backward_C5();
				Backward_S4();
				Backward_C3();
				Backward_S2();
				Backward_C1();
				Backward_input();

				UpdateWeights();
			}

			const double accuracy_rate = test();
			std::cout << ", accuracy rate: " << accuracy_rate << std::endl;
			if (accuracy_rate > accuracy_rate_CNN) {
				break; // target accuracy reached; stop early
			}
		}

		// Single save point for both exits (early stop or all epochs completed).
		saveModelFile("E:/GitCode/NN_Test/data/cnn.model");
		std::cout << "generate cnn model" << std::endl;

		return true;
	}
	// Hyperbolic tangent activation, tanh(x), with range (-1, 1).
	// Uses std::tanh rather than the explicit (e^x - e^-x) / (e^x + e^-x) form:
	// for |x| above roughly 709.8, std::exp overflows to infinity and the
	// explicit formula evaluates inf/inf = NaN, which would poison the whole
	// network. std::tanh saturates correctly to +/-1.
	double CNN::activation_function_tanh(double x)
	{
		return std::tanh(x);
	}
	// Derivative of tanh written in terms of its output: if y = tanh(z), then
	// dy/dz = 1 - y^2.
	// NOTE(review): x is therefore presumably the already-activated value, not
	// the pre-activation input — confirm at the call sites.
	double CNN::activation_function_tanh_derivative(double x)
	{
		const double x_squared = x * x;
		return 1.0 - x_squared;
	}
	// Identity (linear) activation: returns its input unchanged.
	double CNN::activation_function_identity(double x)
	{
		const double passthrough = x;
		return passthrough;
	}
	// Derivative of the identity activation: the constant slope 1 for every input.
	double CNN::activation_function_identity_derivative(double x)
	{
		static_cast<void>(x); // slope is independent of x
		return 1.0;
	}
	// Half squared-error loss for one output: (y - t)^2 / 2.
	double CNN::loss_function_mse(double y, double t)
	{
		const double diff = y - t;
		return diff * diff / 2;
	}
	// Derivative of the half squared-error loss with respect to y:
	// d/dy [(y - t)^2 / 2] = y - t.
	double CNN::loss_function_mse_derivative(double y, double t)
	{
		const double residual = y - t;
		return residual;
	}
// 電子發燒友App (Elecfans app banner; web-page residue, not source code)
// 電子發燒友App
                     
                 
                 
           
        
 
        











 
            
             
             
                 
// 工商網監 (business-registration web badge; web-page residue, not source code)
// 工商網監
        
// 評論 ("Comments" link; web-page residue, not source code)