1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
use variable::{Variable};
use std::borrow::Borrow;

cpp! {{
  #include <CNTKLibrary.h>
  #include <cstdio>
  #include <vector>

  using namespace CNTK;
  using namespace std;
}}

/// Opaque, by-value storage for a C++ `TrainingParameterSchedule<double>`.
///
/// NOTE(review): the array length must match `sizeof(TrainingParameterSchedule<double>)`
/// (in 8-byte words) for the linked CNTK build — re-check when upgrading CNTK.
type DoubleParameterScheduleInner = [u64; 9usize];

/// A schedule of `f64` training-parameter values (used for learning-rate and momentum
/// schedules by `Learner`), backed by a C++ `TrainingParameterSchedule<double>`.
///
/// The wrapped C++ object is destroyed when this value is dropped.
#[derive(Debug)]
pub struct DoubleParameterSchedule {
    pub(super) payload: DoubleParameterScheduleInner
}

impl DoubleParameterSchedule {
    /// Builds a constant schedule that always yields `x`.
    ///
    /// The value is interpreted per sample: the C++ side constructs a CNTK
    /// `TrainingParameterPerSampleSchedule` from `x`.
    pub fn constant(x: f64) -> DoubleParameterSchedule {
        // SAFETY: the snippet only constructs a C++ schedule from a plain `double` and returns
        // it by value; ownership passes to the new `DoubleParameterSchedule`, whose `Drop`
        // impl runs the C++ destructor.
        let payload = unsafe {
            cpp!([x as "double"] -> DoubleParameterScheduleInner as "TrainingParameterSchedule<double>" {
                return TrainingParameterPerSampleSchedule(x);
            })
        };
        DoubleParameterSchedule { payload }
    }
}

impl Drop for DoubleParameterSchedule {
    /// Runs the C++ destructor of the wrapped `TrainingParameterSchedule<double>`.
    fn drop(&mut self) {
        // `cpp!` captures name local variables, so copy the opaque bytes into one. The bytes
        // left in `self.payload` are never touched again after this.
        let schedule = self.payload;
        // SAFETY: `payload` holds a live C++ schedule, and `drop` runs at most once per value,
        // so the object is destroyed exactly once.
        unsafe {
            cpp!([schedule as "TrainingParameterSchedule<double>"] {
                schedule.~TrainingParameterSchedule();
            })
        }
    }
}

/// Opaque, by-value storage for a C++ `LearnerPtr` (two 8-byte words).
///
/// NOTE(review): presumably CNTK's `std::shared_ptr<Learner>` typedef — confirm the size
/// against the linked CNTK build.
type LearnerInner = [u64; 2];

/// A CNTK learner (optimizer) over a fixed set of parameters.
///
/// Created with `Learner::sgd`, `Learner::momentum_sgd` or `Learner::adam`; the wrapped
/// `LearnerPtr` is destroyed when this value is dropped.
#[derive(Debug)]
pub struct Learner {
    pub(super) payload: LearnerInner,
}

impl Learner {
    pub fn sgd<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);

        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        Learner { payload: unsafe {
            cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                return SGDLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule);
            })
        }}
    }

    pub fn momentum_sgd<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule, momentum_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);

        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        let mschedule = momentum_schedule.payload;
        Learner { payload: unsafe {
            cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>", mschedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                return MomentumSGDLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule, mschedule);
            })
        }}
    }

    pub fn adam<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule, momentum_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);

        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        let mschedule = momentum_schedule.payload;
        Learner { payload: unsafe {
            cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>", mschedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                return AdamLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule, mschedule);
            })
        }}
    }
}

impl Drop for Learner {
    /// Runs the C++ destructor of the wrapped `LearnerPtr`.
    fn drop(&mut self) {
        // `cpp!` captures name local variables, so copy the opaque bytes into one. The bytes
        // left in `self.payload` are never touched again after this.
        let learner_ptr = self.payload;
        // SAFETY: `payload` holds a live C++ `LearnerPtr`, and `drop` runs at most once per
        // value, so it is destroyed exactly once.
        unsafe {
            cpp!([learner_ptr as "LearnerPtr"] {
                learner_ptr.~LearnerPtr();
            })
        }
    }
}

/// Asserts that every element of `parameters` is a parameter variable.
///
/// # Panics
///
/// Panics, naming the index of the first offending element, if any element is not a
/// parameter.
fn check_parameters<T: Borrow<Variable>>(parameters: &[T]) {
    for (index, parameter) in parameters.iter().enumerate() {
        assert!(
            parameter.borrow().is_parameter(),
            "learner parameter at index {} is not a parameter variable",
            index
        );
    }
}