use variable::Variable;
use std::borrow::Borrow;

cpp! {{
    #include <CNTKLibrary.h>
    #include <cstdio>
    #include <vector>

    using namespace CNTK;
    using namespace std;
}}

// Opaque inline storage for a C++ `TrainingParameterSchedule<double>`;
// its size must match `sizeof(TrainingParameterSchedule<double>)`.
type DoubleParameterScheduleInner = [u64; 9usize];

/// A CNTK `TrainingParameterSchedule<double>`, e.g. a learning-rate or momentum schedule.
pub struct DoubleParameterSchedule {
    pub(super) payload: DoubleParameterScheduleInner
}

impl DoubleParameterSchedule {
    /// Creates a per-sample schedule with the constant value `x`.
    pub fn constant(x: f64) -> DoubleParameterSchedule {
        DoubleParameterSchedule {
            payload: unsafe {
                cpp!([x as "double"] -> DoubleParameterScheduleInner as "TrainingParameterSchedule<double>" {
                    return TrainingParameterPerSampleSchedule(x);
                })
            }
        }
    }
}

impl Drop for DoubleParameterSchedule {
    fn drop(&mut self) {
        // Run the C++ destructor on a bitwise copy of the stored object;
        // `self.payload` is never used again after this point.
        let payload = self.payload;
        unsafe {
            cpp!([payload as "TrainingParameterSchedule<double>"] {
                payload.~TrainingParameterSchedule();
            })
        };
    }
}

// Opaque inline storage for a C++ `LearnerPtr` (a `std::shared_ptr<Learner>`).
type LearnerInner = [u64; 2usize];

/// Owning handle to a CNTK learner (`LearnerPtr`).
#[derive(Debug)]
pub struct Learner {
    pub(super) payload: LearnerInner
}

impl Learner {
    /// Creates a plain SGD learner that updates `parameters` according to `learning_rate_schedule`.
    ///
    /// # Panics
    ///
    /// Panics if any element of `parameters` is not a parameter.
    pub fn sgd<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);
        // Clone the handles into a contiguous buffer that outlives the FFI call.
        // The buffer is read as `Parameter*`, which relies on `Parameter` having
        // the same layout as `Variable`.
        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        Learner {
            payload: unsafe {
                cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                    return SGDLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule);
                })
            }
        }
    }

    /// Creates an SGD learner with momentum.
    ///
    /// # Panics
    ///
    /// Panics if any element of `parameters` is not a parameter.
    pub fn momentum_sgd<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule, momentum_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);
        // Same buffer handling as in `sgd`.
        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        let mschedule = momentum_schedule.payload;
        Learner {
            payload: unsafe {
                cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>", mschedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                    return MomentumSGDLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule, mschedule);
                })
            }
        }
    }

    /// Creates an Adam learner. `momentum_schedule` is passed as Adam's momentum
    /// schedule; all other Adam options keep CNTK's defaults.
    ///
    /// # Panics
    ///
    /// Panics if any element of `parameters` is not a parameter.
    pub fn adam<T: Borrow<Variable>>(parameters: &[T], learning_rate_schedule: &DoubleParameterSchedule, momentum_schedule: &DoubleParameterSchedule) -> Learner {
        check_parameters(parameters);
        // Same buffer handling as in `sgd`.
        let data: Vec<Variable> = parameters.iter().map(|x| x.borrow().clone()).collect();
        let data_ptr = data.as_ptr();
        let data_size = data.len();
        let schedule = learning_rate_schedule.payload;
        let mschedule = momentum_schedule.payload;
        Learner {
            payload: unsafe {
                cpp!([data_ptr as "Parameter*", data_size as "size_t", schedule as "TrainingParameterSchedule<double>", mschedule as "TrainingParameterSchedule<double>"] -> LearnerInner as "LearnerPtr" {
                    return AdamLearner(vector<Parameter>(data_ptr, data_ptr + data_size), schedule, mschedule);
                })
            }
        }
    }
}

impl Drop for Learner {
    fn drop(&mut self) {
        // Release the shared pointer by running its destructor on a bitwise copy
        // of the stored value; `self.payload` is never used again after this point.
        let payload = self.payload;
        unsafe {
            cpp!([payload as "LearnerPtr"] {
                payload.~LearnerPtr();
            })
        };
    }
}

/// Panics unless every element of `parameters` is a parameter.
fn check_parameters<T: Borrow<Variable>>(parameters: &[T]) {
    for parameter in parameters {
        assert!(parameter.borrow().is_parameter());
    }
}