C# Class Seq2SeqSharp.Tools.BaseSeq2SeqFramework

This is a framework for neural network training. It implements the core parts of training: backward propagation, parameter updates, memory management, computing graph management, corpus shuffling and batching, model I/O, logging and monitoring, and checkpoints. To use it, create your network as a class inherited from this class, implement only the forward part, and pass it to the TrainOneEpoch method for training.
Project: SciSharp/Seq2SeqSharp
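The division of labour described above (the base class owns the training loop, the subclass supplies only the forward pass and hands it to TrainOneEpoch) can be sketched with a toy, single-device example. All names below (ToyFramework, ToyLinearRegressor, the Func<float[], (float X, float Y), float> forward signature) are hypothetical and are not the Seq2SeqSharp API. Seq2SeqSharp back-propagates through a computing graph; this sketch swaps in central finite differences so that the forward function really is the only model-specific code.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical toy framework (not the Seq2SeqSharp API): it owns the parameters,
// gradient computation, parameter updates and the epoch loop.
public class ToyFramework
{
    public float[] Weights { get; }

    protected ToyFramework(int numWeights) => Weights = new float[numWeights];

    // One pass over the corpus with per-example SGD; returns the mean loss seen before each update.
    // Gradients use central finite differences as a stand-in for backpropagation.
    protected float TrainOneEpoch(IEnumerable<(float X, float Y)> corpus, float learningRate,
                                  Func<float[], (float X, float Y), float> forward)
    {
        const float eps = 1e-3f;
        float totalLoss = 0f;
        int count = 0;
        foreach (var example in corpus)
        {
            totalLoss += forward(Weights, example);
            count++;

            var grads = new float[Weights.Length];
            for (int i = 0; i < Weights.Length; i++)
            {
                float saved = Weights[i];
                Weights[i] = saved + eps;
                float lossPlus = forward(Weights, example);
                Weights[i] = saved - eps;
                float lossMinus = forward(Weights, example);
                Weights[i] = saved; // restore the original weight
                grads[i] = (lossPlus - lossMinus) / (2 * eps);
            }

            for (int i = 0; i < Weights.Length; i++)
                Weights[i] -= learningRate * grads[i];
        }
        return count == 0 ? 0f : totalLoss / count;
    }
}

// Hypothetical user network: inherits the framework and implements only the forward pass.
public class ToyLinearRegressor : ToyFramework
{
    public ToyLinearRegressor() : base(2) { }

    // Forward pass: squared error of the prediction w[0] * x + w[1].
    private static float Forward(float[] w, (float X, float Y) example)
    {
        float pred = w[0] * example.X + w[1];
        return (pred - example.Y) * (pred - example.Y);
    }

    public float Train(IEnumerable<(float X, float Y)> corpus, float learningRate) =>
        TrainOneEpoch(corpus, learningRate, Forward);
}
```

Calling `model.Train(data, 0.02f)` a few hundred times on the four points of y = 2x + 1 should drive the weights toward slope 2 and intercept 1, even though the subclass never computes a gradient itself.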

Public Methods

Method Description
BaseSeq2SeqFramework ( Array deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f, Array compilerOptions = null )
CreateComputGraph ( int deviceIdIdx, bool needBack = true ) : IComputeGraph
LoadModel ( Func InitializeParameters ) : IModelMetaData

Loads the model from the given file.

RunNetwork ( Func ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor ) : AdvUtils
SaveModel ( IModelMetaData modelMetaData ) : bool

Private Methods

Method Description
CopyWeightsFromDefaultDeviceToAllOtherDevices ( ) : void

Copies the weights from the default device to all other devices.

CreateCheckPoint ( IEnumerable validCorpus, List metrics, IModelMetaData modelMetaData, Func ForwardOnSingleDevice, double avgCostPerWordInTotal ) : void
GetParametersFromDefaultDevice ( ) : List
LoadParameters ( Stream stream ) : void
Register ( object childValue, string name ) : void
RegisterTrainableParameters ( object obj ) : void
RunTest ( List inputTokens, Func ForwardOnSingleDevice ) : List>
RunValid ( IEnumerable validCorpus, Func RunNetwork, List metrics, bool outputToFile = false ) : bool

Evaluates the quality of the model on the validation corpus.

RunValidParallel ( Func RunNetwork, List metrics, bool outputToFile, List srcSents, List refSents, List hypSents, List sntPairBatchs ) : void
SaveParameters ( Stream stream ) : void
SumGradientsToTensorsInDefaultDevice ( ) : void

Sums the gradients from all devices and keeps the result on the default device.

TrainOneEpoch ( int ep, IEnumerable trainCorpus, IEnumerable validCorpus, ILearningRate learningRate, AdamOptimizer solver, List metrics, IModelMetaData modelMetaData, Func ForwardOnSingleDevice ) : void
ZeroGradientOnAllDevices ( ) : void
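The private helpers ZeroGradientOnAllDevices, SumGradientsToTensorsInDefaultDevice and CopyWeightsFromDefaultDeviceToAllOtherDevices fit the usual data-parallel update cycle: each device holds a replica of the weights, gradients are gathered on the default device, the update is applied there, and the new weights are copied back out. The sketch below assumes plain float arrays stand in for per-device tensors and that device 0 is the default device. ToyReplicas and its members are hypothetical, plain SGD stands in for the AdamOptimizer that TrainOneEpoch accepts, and the exact order in which Seq2SeqSharp calls its helpers is an assumption.

```csharp
using System;

// Hypothetical stand-in for per-device parameter replicas; device 0 plays the "default" device.
public class ToyReplicas
{
    public float[][] Weights { get; } // Weights[device][i]
    public float[][] Grads { get; }   // Grads[device][i]

    public ToyReplicas(int devices, int size)
    {
        Weights = new float[devices][];
        Grads = new float[devices][];
        for (int d = 0; d < devices; d++)
        {
            Weights[d] = new float[size];
            Grads[d] = new float[size];
        }
    }

    // Reset every device's gradient buffer before a new batch.
    public void ZeroAllGradients()
    {
        foreach (var g in Grads) Array.Clear(g, 0, g.Length);
    }

    // Accumulate the gradients of devices 1..N-1 into device 0.
    public void SumGradientsToDefault()
    {
        for (int d = 1; d < Grads.Length; d++)
            for (int i = 0; i < Grads[0].Length; i++)
                Grads[0][i] += Grads[d][i];
    }

    // Plain SGD step, applied on the default device only.
    public void ApplySgdOnDefault(float learningRate)
    {
        for (int i = 0; i < Weights[0].Length; i++)
            Weights[0][i] -= learningRate * Grads[0][i];
    }

    // Overwrite every other replica with the updated weights of device 0.
    public void BroadcastWeightsFromDefault()
    {
        for (int d = 1; d < Weights.Length; d++)
            Array.Copy(Weights[0], Weights[d], Weights[0].Length);
    }
}
```

Under these assumptions one training step runs ZeroAllGradients, then a forward and backward pass per device, then SumGradientsToDefault, ApplySgdOnDefault and BroadcastWeightsFromDefault. Applying the update on a single device means the optimizer state only has to be kept in one place.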

Method Details

BaseSeq2SeqFramework() public constructor

public BaseSeq2SeqFramework ( Array deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f, Array compilerOptions = null )
deviceIds Array
processorType ProcessorTypeEnums
modelFilePath string
memoryUsageRatio float
compilerOptions Array

CreateComputGraph() public method

public CreateComputGraph ( int deviceIdIdx, bool needBack = true ) : IComputeGraph
deviceIdIdx int
needBack bool
return IComputeGraph

LoadModel() public method

Loads the model from the given file.
public LoadModel ( Func InitializeParameters ) : IModelMetaData
InitializeParameters Func
return IModelMetaData

RunNetwork() public method

public RunNetwork ( Func ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor ) : AdvUtils
ForwardOnSingleDevice Func
sntPairBatchs List
batchSplitFactor int
return AdvUtils

SaveModel() public method

public SaveModel ( IModelMetaData modelMetaData ) : bool
modelMetaData IModelMetaData
return bool
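The method lists above pair SaveModel and LoadModel with private SaveParameters(Stream) and LoadParameters(Stream) helpers, which suggests that model I/O comes down to writing and reading parameter tensors through a stream. A minimal round-trip sketch, assuming parameters are plain float arrays; the class name ToyParameterIO and its length-prefixed binary layout are hypothetical and do not describe Seq2SeqSharp's actual file format.

```csharp
using System.Collections.Generic;
using System.IO;
using System.Text;

// Hypothetical parameter (de)serializer; the layout is illustrative, not Seq2SeqSharp's format.
public static class ToyParameterIO
{
    // Layout: tensor count, then for each tensor its length followed by its float values.
    public static void SaveParameters(Stream stream, IReadOnlyList<float[]> tensors)
    {
        // leaveOpen: true so the caller still owns (and can rewind) the stream.
        using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
        writer.Write(tensors.Count);
        foreach (var t in tensors)
        {
            writer.Write(t.Length);
            foreach (var v in t) writer.Write(v);
        }
    }

    // Reads back tensors written by SaveParameters, in the same order.
    public static List<float[]> LoadParameters(Stream stream)
    {
        using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
        int count = reader.ReadInt32();
        var tensors = new List<float[]>(count);
        for (int k = 0; k < count; k++)
        {
            var t = new float[reader.ReadInt32()];
            for (int i = 0; i < t.Length; i++) t[i] = reader.ReadSingle();
            tensors.Add(t);
        }
        return tensors;
    }
}
```

Prefixing each tensor with its length lets the loader rebuild the arrays without knowing the network shape in advance; a real model file would also need the metadata that LoadModel returns as IModelMetaData.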