<template>
  <b-card :class="{ 'bg-light-yellow': isEditing, 'bg-light-green': !isEditing }" header-tag="header">
    <!-- Create/edit form for a model training; card colour and header text follow isEditing. -->
    <div slot="header">
      <!-- Pencil marks editing an existing training, plus marks adding a new one. -->
      <span v-if="isEditing"><i class="fa icon-pencil" /> {{ $t('global.edit') }} </span>
      <span v-else><i class="fa icon-plus" /> {{ $t('global.add_new') }} </span>
      <b-button size="sm" @click="setDefaults">
        {{ $t('global.set_defaults') }}
      </b-button>
      <div class="card-actions">
        <!-- .prevent stops the "#" href from changing the URL hash and scrolling the page. -->
        <a href="#" class="btn btn-setting" :aria-label="$t('global.cancel')" @click.prevent="$emit('cancel', true)">
          <i class="fa icon-close" aria-hidden="true" />
        </a>
      </div>
    </div>
    <form @submit.prevent="save">
      <b-form-group :label="$t('global.name')" :state="!errors.name" :invalid-feedback="errors.name">
        <b-form-input v-model="trainingData.name" type="text" required />
      </b-form-group>
      <b-form-group :label="$t('model_training.init_from_id')">
        <b-form-select v-model="trainingData.config.init_from_id" :options="optionsInitFrom" />
      </b-form-group>
      <b-form-group>
        <!-- Optional boolean flag: it must stay submittable while unchecked, so it is not marked required. -->
        <b-form-checkbox v-model="trainingData.config.freeze_convolutions">
          {{ $t('model_training.freeze_convolutions') }}
        </b-form-checkbox>
      </b-form-group>
      <b-form-group :label="$t('model_training.batch_size_train')">
        <b-form-input v-model.number="trainingData.config.batch_size_train" type="text" required />
      </b-form-group>
      <b-form-group :label="$t('model_training.batch_size_val')">
        <b-form-input v-model.number="trainingData.config.batch_size_val" type="text" required />
      </b-form-group>
      <b-form-group :label="$t('model_training.number_of_epochs')">
        <b-form-input v-model.number="trainingData.config.number_of_epochs" type="text" required />
      </b-form-group>
      <b-form-group :label="$t('model_training.learning_rate')">
        <b-form-input v-model.number="trainingData.config.learning_rate" type="text" required />
      </b-form-group>
      <b-form-group :label="$t('model_training.max_class_weight')">
        <b-form-input v-model.number="trainingData.config.max_class_weight" type="text" required />
        <!-- Derived statistics; classBalancingStats is null when they cannot be computed. -->
        <template v-if="classBalancingStats">
          <b-form-text>
            {{ $t('model_training.oversampling_threshold') }}: {{ classBalancingStats.oversamplingThreshold }}
          </b-form-text>
          <b-form-text>
            {{ $t('model_training.calculated_min_class_weight') }}: {{ classBalancingStats.minClassWeight }}
          </b-form-text>
          <b-form-text>
            {{ $t('model_training.calculated_max_class_weight') }}: {{ classBalancingStats.maxClassWeight }}
          </b-form-text>
        </template>
      </b-form-group>
      <div>
        <b-btn v-if="isEditing || showCancel" variant="outline-danger" @click="$emit('cancel', true)">
          {{ $t('global.cancel') }}
        </b-btn>
        <!-- One submit button serves both the add and the edit case. -->
        <b-btn class="pull-right" type="submit" variant="success">
          {{ $t('global.save') }}
        </b-btn>
      </div>
    </form>
  </b-card>
</template>
<script>

import { VuexTypes } from '@/store/types';
import { mapActions, mapState } from 'vuex';

export default {
  name: 'ModelTrainingForm',

  props: {
    model: {
      type: Object,
      required: true,
    },
    training: {
      type: Object,
      required: true,
    },
    allTrainings: {
      type: Array,
      required: true,
    },
    modelId: {
      type: Number,
      required: true,
    },
    showCancel: {
      type: Boolean,
      required: false,
      default() {
        return false;
      },
    },
  },
  data() {
    return {
      trainingData: _.cloneDeep(this.training),
      errors: {},
      defaultConfig: {
        init_from_id: null,
        freeze_convolutions: false,
        batch_size_train: 5,
        batch_size_val: 5,
        number_of_epochs: 200,
        learning_rate: 0.0003,
        max_class_weight: 4,
      },
      statuses: [
        { value: null, text: '' },
        { value: 'created', text: 'created' },
        { value: 'sent_for_preparing', text: 'sent_for_preparing' },
        { value: 'training', text: 'training' },
        { value: 'testing', text: 'testing' },
        { value: 'ready', text: 'ready' },
        { value: 'queue', text: 'queue' },
      ],
    };
  },
  computed: {
    ...mapState(['modelTraining']),
    isEditing() {
      return typeof this.trainingData.id !== 'undefined' && this.trainingData.id > 0;
    },
    optionsInitFrom() {
      // Exclude self and trainings with 0 progress
      let valid = this.allTrainings.filter((item) => {
        return item.id !== this.trainingData.id && item.progress > 0;
      }).map((item) => ({
        value: item.id,
        text: `${item.id} - ${item.name} (${item.progress}e)`,
      }));

      // Add the null option
      valid.unshift({ value: null, text: this.$t('global.none') });

      return valid;
    },
    classBalancingStats() {
      // This is a port of the classification trainer logic that calculates
      // the class weights and balancing threshold
      let classes = this.model?.model_classes;
      if (!classes || !this.trainingData.config.max_class_weight) {
        return null;
      }

      let initEstimate = this.invClassWeight(this.trainingData.config.max_class_weight, 0);
      let estimate = initEstimate;
      let bias = 0;
      for (let i = 0; i < 10; i++) {
        bias = 0;
        for (const c of this.model.model_classes) {
          bias += Math.max(0, estimate - c.segmented_images_train_count);
        }
        bias = Math.ceil(bias);
        estimate = this.invClassWeight(this.trainingData.config.max_class_weight, bias);
      }
      
      let oversamplingThreshold = Math.ceil(estimate);

      let minClassWeight = Number.POSITIVE_INFINITY;
      let maxClassWeight = Number.NEGATIVE_INFINITY;
      for (const c of this.model.model_classes) {
        let weight = this.classWeight(Math.max(c.segmented_images_train_count, oversamplingThreshold), bias);
        minClassWeight = Math.min(minClassWeight, weight);
        maxClassWeight = Math.max(maxClassWeight, weight);
      }

      return {
        oversamplingThreshold,
        minClassWeight,
        maxClassWeight,
      };
    },
  },
  methods: {
    ...mapActions({
      add: VuexTypes.MODEL_TRAINING_ADD,
      update: VuexTypes.MODEL_TRAINING_UPDATE,
    }),
    classWeight(num_samples, bias) {
      let classes = this.model?.model_classes;
      let total = classes.map((item) => item.segmented_images_train_count).reduce((a, b) => a + b, 0);
      return (total + bias) / (classes.length * num_samples);
    },
    invClassWeight(weight, bias) {
      // The function is symmetric across the diagonal, so it is its own inverse function
      return this.classWeight(weight, bias);
    },
    save() {
      this.trainingData.model_id = this.modelId;
      const action = this.isEditing ? this.update(this.trainingData) : this.add(this.trainingData);

      action
        .then((response) => {
          this.errors = {};
          this.$emit('saved', Object.assign({}, response.data));
        }).catch((error) => {
          this.errors = {};
          if (typeof error.data !== 'undefined') {
            this.errors = error.data;
          }
          for (const item in error.data) {
            this.errors[item] = this.errors[item].join(' ');
          }
        });
    },
    setDefaults() {
      this.trainingData.config = Object.assign({}, this.defaultConfig);
    },
  },
};
</script>
<style></style>
