import scipy.sparse as sps
import tensorflow as tf
import numpy as np
import time
from antk.core import loader
import os
import datetime
import matplotlib.pyplot as plt
# ============================================================================================
# ============================CONVENIENCE DICTIONARY==========================================
# ============================================================================================
OPT = {'adam': tf.train.AdamOptimizer,
       'ada': tf.train.AdagradOptimizer,
       'grad': tf.train.GradientDescentOptimizer,
       'mom': tf.train.MomentumOptimizer}
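# e.g. OPT['adam'](0.001) constructs a tf.train.AdamOptimizer with a learning rate of 0.001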
# ============================================================================================
# ============================GLOBAL MODULE FUNCTIONS=========================================
# ============================================================================================
def get_feed_list(batch, placeholderdict, supplement=None, dropouts=None, dropout_flag='train'):
    """
    :param batch: A dataset object.
    :param placeholderdict: A dictionary where the keys match keys in batch, and the values are placeholder tensors
    :param supplement: A dictionary of numpy input matrices with keys corresponding to placeholders in placeholderdict, where the row size of the matrices does not correspond to the number of datapoints. For use with input data intended for `embedding_lookup`_.
    :param dropouts: Dropout tensors in graph.
    :param dropout_flag: 'train' to feed the dropout keep probabilities from *dropouts*, or 'eval' to feed 1.0 (no dropout).
    :return: A feed dictionary with keys of placeholder tensors and values of numpy matrices, paired by key
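
    Example (a minimal sketch; ``data`` is a hypothetical :any:`DataSet`, and
    ``ph_dict``, ``train_step``, and ``session`` are assumed to exist)::

        batch = data.next_batch(100)
        fd = get_feed_list(batch, ph_dict,
                           dropouts=tf.get_collection('dropout_prob'),
                           dropout_flag='train')
        session.run(train_step, feed_dict=fd)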
    """
    ph, dt = [], []
    datadict = batch.features.copy()
    datadict.update(batch.labels)
    if supplement:
        datadict.update(supplement)
    for desc in placeholderdict:
        ph.append(placeholderdict[desc])
        if sps.issparse(datadict[desc]):
            dt.append(datadict[desc].todense().astype(float, copy=False))
        elif isinstance(datadict[desc], loader.HotIndex):
            dt.append(datadict[desc].vec)
        else:
            dt.append(datadict[desc])
    if dropouts:
        for prob in dropouts:
            ph.append(prob[0])
            if dropout_flag == 'train':
                dt.append(prob[1])
            elif dropout_flag == 'eval':
                dt.append(1.0)
            else:
                raise ValueError('dropout_flag must be "train" or "eval". Found %s' % dropout_flag)
    return {i: d for i, d in zip(ph, dt)} 
def parse_summary_val(summary_str):
    """
    Helper function to parse numeric values from the string returned by running a tf.scalar_summary op.
    :param summary_str: Return value from running a session on a tf.scalar_summary op.
    :return: A dictionary mapping summary tags to their numeric values.
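
    Example (assuming ``loss_summary = tf.scalar_summary('Loss', objective)``
    and a feed dictionary ``fd``)::

        summary_str = session.run(loss_summary, feed_dict=fd)
        parse_summary_val(summary_str)  # e.g. {'Loss': 0.25}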
    """
    summary_proto = tf.Summary()
    summary_proto.ParseFromString(summary_str)
    summaries = {}
    for val in summary_proto.value:
        summaries[val.tag] = val.simple_value
    return summaries 
# ============================================================================================
# ============================GENERIC MODEL CLASS=============================================
# ============================================================================================
class Model(object):
    """
    Generic model builder for training and predictions.
    :param objective: Loss function
    :param placeholderdict: A dictionary of placeholders
    :param maxbadcount: For early stopping
    :param momentum: The momentum for tf.MomentumOptimizer
    :param mb: The mini-batch size
    :param verbose: Whether to print dev error and save_tensor evaluations during training.
    :param epochs: Maximum number of epochs to train for.
    :param learnrate: Learning rate for gradient descent.
    :param save: Save best model to *best_model_path*.
    :param opt: Optimization strategy. May be 'adam', 'ada', 'grad', or 'mom'.
    :param decay: A [decay_step, decay_rate] pair for exponentially decaying the learn rate.
    :param evaluate: Evaluation metric
    :param predictions: Predictions selected from feed forward pass.
    :param logdir: Where to put the tensorboard data.
    :param random_seed: Random seed for TensorFlow initializers.
    :param model_name: Name for the model, used to name the tensorboard log directory.
    :param clip_gradients: The limit on the global gradient norm. If 0.0, no clipping is performed.
    :param make_histograms: Whether or not to make tensorboard histograms of model weights and activations.
    :param best_model_path: File to save best model to during training.
    :param save_tensors: A hashmap of str:Tensor mappings. Tensors are evaluated during training. Evaluations of these tensors on best model are accessible via property :any:`evaluated_tensors`.
    :param tensorboard: Whether to write tensorboard summaries: graphs of loss and dev_error, plus weight and activation histograms when *make_histograms* is set.
    :param train_evaluate: Evaluation metric to periodically compute on the training set during training (see :any:`train`).
    :return: :any:`Model`
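
    Example (a minimal linear-regression sketch; the placeholder keys 'x' and 'y'
    and all shapes are illustrative and must match your :any:`DataSet` keys)::

        x = tf.placeholder(tf.float32, [None, 10], name='x')
        y_ = tf.placeholder(tf.float32, [None, 1], name='y')
        w = tf.Variable(tf.zeros([10, 1]))
        y = tf.matmul(x, w)
        objective = tf.reduce_sum(tf.square(y - y_))
        model = Model(objective, {'x': x, 'y': y_},
                      mb=100, learnrate=0.01, opt='grad',
                      evaluate=objective, predictions=y)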
    """
    def __init__(self, objective, placeholderdict,
                 maxbadcount=20, momentum=None, mb=1000, verbose=True,
                 epochs=50, learnrate=0.003, save=False, opt='grad',
                 decay=(1, 1.0), evaluate=None, predictions=None,
                 logdir='log/', random_seed=None, model_name='generic',
                 clip_gradients=0.0, make_histograms=False,
                 best_model_path='/tmp/model.ckpt',
                 save_tensors=None, tensorboard=False, train_evaluate=None):
        self.objective = objective
        for t in tf.get_collection('losses'):
            self.objective += t
        self._placeholderdict = placeholderdict
        self.maxbadcount = maxbadcount
        self.momentum = momentum
        self.mb = mb
        self.verbose = verbose
        self.epochs = epochs
        self.learnrate = learnrate
        self.save = save
        self.opt = opt
        self.decay = decay
        self.epoch_times = []
        self.evaluate = evaluate
        self.train_evaluate = train_evaluate
        self._best_dev_error = float('inf')
        self.predictor = predictions
        self.random_seed = random_seed
        self.session = tf.Session()
        if self.random_seed is not None:
            tf.set_random_seed(self.random_seed)
        self.model_name = model_name
        self.clip_gradients = clip_gradients
        self.tensorboard = tensorboard
        self.make_histograms = make_histograms
        if self.make_histograms:
            self.tensorboard = True
        self.histogram_summaries = []
        if not logdir.endswith('/'):
            self.logdir = logdir + '/'
        else:
            self.logdir = logdir
        if not os.path.isdir(self.logdir):
            os.makedirs(self.logdir)
        self.save_tensors = save_tensors if save_tensors is not None else {}
        self._completed_epochs = 0.0
        self._best_completed_epochs = 0.0
        self._evaluated_tensors = {}
        self.deverror = []
        self._badcount = 0
        self.batch = tf.Variable(0, trainable=False)  # mini-batch counter; serves as the optimizer's global step
        self.train_eval = []
        self.dev_spot = []
        self.train_spot = []
        # ================================================================
        # ======================For tensorboard===========================
        # ================================================================
        if self.tensorboard:
            self._init_summaries()
        # =============================================================================
        # ===================OPTIMIZATION STRATEGY=====================================
        # =============================================================================
        optimizer = OPT[self.opt]
        decay_step = self.decay[0]
        decay_rate = self.decay[1]
        # self.batch, incremented by the optimizer, keeps track of the mini-batch iteration
        if not (decay_step == 1 and decay_rate == 1.0):
            self.learnrate = tf.train.exponential_decay(self.learnrate, self.batch*self.mb,
                                                        decay_step, decay_rate, name='learnrate_decay')
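        # Note: with self.batch*self.mb as the step argument, the decayed rate is
        #   learnrate * decay_rate ** ((self.batch*self.mb) / decay_step)
        # so decay_step is measured in training examples seen, not mini-batches.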
        if self.clip_gradients > 0.0:
            params = tf.trainable_variables()
            self.gradients = tf.gradients(self.objective, params)
            self.gradients, self.gradients_norm = tf.clip_by_global_norm(
                self.gradients, self.clip_gradients)
            grads_and_vars = zip(self.gradients, params)
            if self.opt == 'mom':
                self.train_step = optimizer(self.learnrate,
                                       self.momentum).apply_gradients(grads_and_vars,
                                                                      global_step=self.batch,
                                                                      name="train")
            else:
                self.train_step = optimizer(self.learnrate).apply_gradients(grads_and_vars,
                                                                       global_step=self.batch,
                                                                       name="train")
        else:
            if self.opt == 'mom':
                self.train_step = optimizer(self.learnrate,
                                       self.momentum).minimize(self.objective,
                                                               global_step=self.batch)
            else:
                self.train_step = optimizer(self.learnrate).minimize(self.objective,
                                                                global_step=self.batch)
        # =============================================================================
        # ===================Initialize graph =====================================
        # =============================================================================
        self.session.run(tf.initialize_all_variables())
        if save:
            self.saver = tf.train.Saver()
            self.best_model_path = best_model_path
            self.save_path = self.saver.save(self.session, self.best_model_path)
    # ======================================================================
    # ================Properites============================================
    # ======================================================================
    @property
    def placeholderdict(self):
        '''
        Dictionary of model placeholders
        '''
        return self._placeholderdict
    @property
    def best_dev_error(self):
        """
        The best dev error reached during training.
        """
        return self._best_dev_error
    @property
    def average_secs_per_epoch(self):
        """
        The average number of seconds to complete an epoch.
        """
        return np.sum(np.array(self.epoch_times))/self._completed_epochs
    @property
    def evaluated_tensors(self):
        '''
        A dictionary of evaluations on best model for tensors and keys specified by *save_tensors* argument to constructor.
        '''
        return self._evaluated_tensors
    
    @property
    def completed_epochs(self):
        '''
        Number of epochs completed during training (fractional)
        '''
        return self._completed_epochs
    @property
    def best_completed_epochs(self):
        '''
        Number of epochs completed during at point of best dev eval during training (fractional)
        '''
        return self._best_completed_epochs
    def plot_train_dev_eval(self, figure_file='testfig.pdf'):
        """
        Plot the recorded train and dev evaluations against completed epochs and save the figure.
        :param figure_file: File to save the plot to.
        """
        plt.plot(self.dev_spot, self.deverror, label='dev')
        plt.plot(self.train_spot, self.train_eval, label='train')
        plt.ylabel('Error')
        plt.xlabel('Epoch')
        plt.legend(loc='upper right')
        plt.savefig(figure_file) 
    def predict(self, data, supplement=None):
        """
        :param data: :any:`DataSet` to make predictions from.
        :param supplement: A dictionary of supplementary input matrices keyed by placeholder name (see :any:`get_feed_list`).
        :return: A set of predictions from the feed forward pass defined by the *predictions* argument to the constructor.
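
        Example (assuming ``data.test`` is a :any:`DataSet` with keys matching
        :any:`placeholderdict`)::

            predictions = model.predict(data.test)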
        """
        fd = get_feed_list(data, self.placeholderdict, supplement=supplement,
                           dropouts=tf.get_collection('dropout_prob'),
                           dropout_flag='eval')
        return self.session.run(self.predictor,
                                feed_dict=fd) 
    def eval(self, tensor_in, data, supplement=None):
        """
        Evaluate a tensor over a dataset.
        :param tensor_in: Tensor to evaluate, e.g. the *evaluate* metric passed to the constructor.
        :param data: :any:`DataSet` to evaluate on.
        :param supplement: A dictionary of supplementary input matrices keyed by placeholder name (see :any:`get_feed_list`).
        :return: Result of evaluating *tensor_in* on *data*.
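
        Example (``dev_mse`` is a hypothetical scalar metric tensor built on the
        model's placeholders)::

            error = model.eval(dev_mse, data.dev)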
        """
        fd = get_feed_list(data, self.placeholderdict, supplement=supplement,
                           dropouts=tf.get_collection('dropout_prob'),
                           dropout_flag='eval')
        return self.session.run(tensor_in, feed_dict=fd) 
    def train(self, train, dev=None, supplement=None, eval_schedule='epoch', train_dev_eval_factor=3):
        """
        :param train: :any:`DataSet` to train on.
        :param dev: :any:`DataSet` for dev evaluation and early stopping.
        :param supplement: A dictionary of supplementary input matrices keyed by placeholder name (see :any:`get_feed_list`).
        :param eval_schedule: Number of training examples between dev evaluations, or 'epoch' to evaluate once per epoch.
        :param train_dev_eval_factor: When *train_evaluate* is set, evaluate on the training set every *train_dev_eval_factor* * *eval_schedule* examples.
        :return: A trained :any:`Model`
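
        Example (assuming ``data`` holds ``train`` and ``dev`` :any:`DataSet` splits)::

            model.train(data.train, dev=data.dev, eval_schedule='epoch')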
        """
        self._completed_epochs = 0.0
        if self.save:
            self.saver.restore(self.session, self.best_model_path)
        # ========================================================
        # ===========Check data to see if dev eval================
        # ========================================================
        if eval_schedule == 'epoch':
            eval_schedule = train.num_examples
        self._badcount = 0
        start_time = time.time()
        # ============================================================================================
        # =============================TRAINING=======================================================
        # ============================================================================================
        counter = 0
        train_eval_counter = 0
        while self._completed_epochs < self.epochs: # keeps track of the epoch iteration
            # ==============PER MINI-BATCH=====================================
            newbatch = train.next_batch(self.mb)
            fd = get_feed_list(newbatch, self.placeholderdict, supplement,
                               dropouts=tf.get_collection('dropout_prob'))
            self.session.run(self.train_step, feed_dict=fd)
            counter += self.mb
            train_eval_counter += self.mb
            self._completed_epochs += float(self.mb)/float(train.num_examples)
            if self.train_evaluate and train_eval_counter >= train_dev_eval_factor*eval_schedule:
                self.train_eval.append(self.eval(self.train_evaluate, train, supplement))
                self.train_spot.append(self._completed_epochs)
                if np.isnan(self.train_eval[-1]):
                    print("Aborting training...train evaluates to nan.")
                    break
                if self.verbose:
                    print("epoch: %f train eval: %.10f" % (self._completed_epochs, self.train_eval[-1]))
                train_eval_counter = 0
            if (counter >= eval_schedule or self._completed_epochs >= self.epochs):
                #=================PER eval_schedule==================================
                self._log_summaries(dev, supplement)
                counter = 0
                if dev:
                    self.deverror.append(self.eval(self.evaluate, dev, supplement))
                    self.dev_spot.append(self._completed_epochs)
                    if np.isnan(self.deverror[-1]):
                        print("Aborting training...dev evaluates to nan.")
                        break
                    if self.verbose:
                        print("epoch: %f dev error: %.10f" % (self._completed_epochs, self.deverror[-1]))
                    for tname in self.save_tensors:
                        self._evaluated_tensors[tname] = self.eval(self.save_tensors[tname], dev, supplement)
                        if self.verbose:
                            print("\t%s: %s" % (tname, self._evaluated_tensors[tname]))
                    # ================Early Stopping====================================
                    if self.deverror[-1] < self.best_dev_error:
                        self._badcount = 0
                        self._best_dev_error = self.deverror[-1]
                        if self.save:
                            self.save_path = self.saver.save(self.session, self.best_model_path)
                        self._best_completed_epochs = self._completed_epochs
                    else:
                        self._badcount += 1
                    if self._badcount > self.maxbadcount:
                        print('badcount exceeded: %d' % self._badcount)
                        break
                    # ==================================================================
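            # Elapsed time is recorded per mini-batch; average_secs_per_epoch sums
            # these and divides by the number of completed epochs.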
            self.epoch_times.append(time.time() - start_time)
            start_time = time.time() 
    # ================================================================
    # ======================For tensorboard===========================
    # ================================================================
    def _init_summaries(self):
        if self.make_histograms:
            self.histogram_summaries.extend(map(tf.histogram_summary,
                                      [var.name for var in tf.trainable_variables()],
                                      tf.trainable_variables()))
            self.histogram_summaries.extend(map(tf.histogram_summary,
                                      ['normalization/'+n.name for n in tf.get_collection('normalized_activations')],
                                      tf.get_collection('normalized_activations')))
            self.histogram_summaries.extend(map(tf.histogram_summary,
                                      ['activation/'+a.name for a in tf.get_collection('activation_layers')],
                                      tf.get_collection('activation_layers')))
        self.loss_summary = tf.scalar_summary('Loss', self.objective)
        self.dev_error_summary = tf.scalar_summary('dev_error', self.evaluate)
        summary_directory = os.path.join(self.logdir,
                                         self.model_name + '-' +
                                         datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        self._summary_writer = tf.train.SummaryWriter(summary_directory,
                                                      self.session.graph.as_graph_def())
    def _log_summaries(self,  dev, supplement):
        fd = get_feed_list(dev, self.placeholderdict, supplement=supplement,
                           dropouts=tf.get_collection('dropout_prob'),
                           dropout_flag='eval')
        if self.tensorboard:
            if self.make_histograms:
                sum_str = self.session.run(self.histogram_summaries, fd)
                for summary in sum_str:
                    self._summary_writer.add_summary(summary, self._completed_epochs)
            loss_sum_str = self.session.run(self.loss_summary, fd)
            self._summary_writer.add_summary(loss_sum_str, self._completed_epochs)
        if dev:
            if self.tensorboard:
                dev_sum_str = self.session.run(self.dev_error_summary, fd)
                self._summary_writer.add_summary(dev_sum_str, self._completed_epochs)