C# Class Seq2SeqSharp.Tools.BaseSeq2SeqFramework

This is a framework for neural network training. It includes many core parts, such as backward propagation, parameter updates, memory management, computing graph management, corpus shuffling and batching, model I/O, logging and monitoring, and checkpoints. To use it, create your network as a class that inherits from this class, implement only the forward part, and pass your forward function to the TrainOneEpoch method for training.
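The division of labor described above can be sketched in plain C#. This is a hypothetical, dependency-free illustration, not Seq2SeqSharp's code: `MiniFramework`, `LinearModel`, `Forward`, and the single scalar weight are invented for the example. The base class owns the training loop (shuffling, batching, zeroing gradients, updating parameters) and the subclass supplies only the forward pass. Because this sketch has no automatic differentiation, its forward pass also accumulates its own gradient; in the real framework that job belongs to the computing graph's backward pass.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical base class: owns the training loop, not the model math.
public abstract class MiniFramework
{
    // One trainable scalar weight and its gradient (stand-ins for real tensors).
    public float Weight;
    public float Gradient;

    // Subclasses implement only the forward pass: return the batch cost and
    // accumulate d(cost)/d(Weight) into Gradient (no autodiff in this sketch).
    protected abstract float Forward(IReadOnlyList<(float X, float Y)> batch);

    // The framework shuffles the corpus, cuts it into batches, zeroes the gradient,
    // calls the subclass's forward pass, and applies a plain SGD update.
    public float TrainOneEpoch(IList<(float X, float Y)> corpus, int batchSize, float learningRate, Random rng)
    {
        var shuffled = corpus.OrderBy(_ => rng.Next()).ToList();
        float totalCost = 0f;
        for (int i = 0; i < shuffled.Count; i += batchSize)
        {
            var batch = shuffled.Skip(i).Take(batchSize).ToList();
            Gradient = 0f;                       // zero the gradient
            totalCost += Forward(batch);         // user-defined forward (+ gradient)
            Weight -= learningRate * Gradient;   // parameter update
        }
        return totalCost;
    }
}

// Example "network": predicts y = Weight * x, trained with mean squared error.
public sealed class LinearModel : MiniFramework
{
    protected override float Forward(IReadOnlyList<(float X, float Y)> batch)
    {
        float cost = 0f;
        foreach (var (x, y) in batch)
        {
            float err = Weight * x - y;
            cost += err * err / batch.Count;
            Gradient += 2f * err * x / batch.Count;
        }
        return cost;
    }
}
```

Training `LinearModel` for a few hundred epochs on points drawn from y = 2x drives `Weight` toward 2, even though the subclass never touches the loop itself.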
Project: SciSharp/Seq2SeqSharp

Public Methods

Method / Description
BaseSeq2SeqFramework ( Array deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f, Array compilerOptions = null ) : AdvUtils
CreateComputGraph ( int deviceIdIdx, bool needBack = true ) : IComputeGraph
LoadModel ( Func InitializeParameters ) : IModelMetaData

Loads the model from the given file.

RunNetwork ( Func ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor ) : AdvUtils
SaveModel ( IModelMetaData modelMetaData ) : bool

Private Methods

Method / Description
CopyWeightsFromDefaultDeviceToAllOtherDevices ( ) : void

Copies the weights from the default device to all other devices.

CreateCheckPoint ( IEnumerable validCorpus, List metrics, IModelMetaData modelMetaData, Func ForwardOnSingleDevice, double avgCostPerWordInTotal ) : void
GetParametersFromDefaultDevice ( ) : List
LoadParameters ( Stream stream ) : void
Register ( object childValue, string name ) : void
RegisterTrainableParameters ( object obj ) : void
RunTest ( List inputTokens, Func ForwardOnSingleDevice ) : List>
RunValid ( IEnumerable validCorpus, Func RunNetwork, List metrics, bool outputToFile = false ) : bool

Evaluates the quality of the model on the validation corpus.

RunValidParallel ( Func RunNetwork, List metrics, bool outputToFile, List srcSents, List refSents, List hypSents, List sntPairBatchs ) : void
SaveParameters ( Stream stream ) : void
SumGradientsToTensorsInDefaultDevice ( ) : void

Sums the gradients from all devices and stores the result on the default device.

TrainOneEpoch ( int ep, IEnumerable trainCorpus, IEnumerable validCorpus, ILearningRate learningRate, AdamOptimizer solver, List metrics, IModelMetaData modelMetaData, Func ForwardOnSingleDevice ) : void
ZeroGradientOnAllDevices ( ) : void
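Taken together, the private methods above suggest a standard data-parallel training step: clear gradients, let each device compute gradients on its share of the batch, sum them on the default device, update the weights there, and copy the updated weights back to every other device. The sketch below shows that sequence with plain arrays. `DeviceReplicas` and its methods are hypothetical and only mirror the names of ZeroGradientOnAllDevices, SumGradientsToTensorsInDefaultDevice, and CopyWeightsFromDefaultDeviceToAllOtherDevices; the real methods work on device tensors, and the framework's optimizer is AdamOptimizer, which this sketch replaces with plain SGD for brevity.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in: each "device" holds its own copy of one weight vector
// and one gradient vector. Device 0 plays the role of the default device.
public sealed class DeviceReplicas
{
    public readonly float[][] Weights;
    public readonly float[][] Gradients;

    public DeviceReplicas(int deviceCount, int size)
    {
        Weights = Enumerable.Range(0, deviceCount).Select(_ => new float[size]).ToArray();
        Gradients = Enumerable.Range(0, deviceCount).Select(_ => new float[size]).ToArray();
    }

    // Mirrors ZeroGradientOnAllDevices: clear every device's gradients before a step.
    public void ZeroGradients()
    {
        foreach (var g in Gradients) Array.Clear(g, 0, g.Length);
    }

    // Mirrors SumGradientsToTensorsInDefaultDevice: accumulate all devices'
    // gradients into the default device (index 0).
    public void SumGradientsIntoDefault()
    {
        for (int d = 1; d < Gradients.Length; d++)
            for (int i = 0; i < Gradients[0].Length; i++)
                Gradients[0][i] += Gradients[d][i];
    }

    // Plain SGD step applied only on the default device (stand-in for Adam).
    public void UpdateDefault(float learningRate)
    {
        for (int i = 0; i < Weights[0].Length; i++)
            Weights[0][i] -= learningRate * Gradients[0][i];
    }

    // Mirrors CopyWeightsFromDefaultDeviceToAllOtherDevices: broadcast the
    // updated weights so that all replicas stay identical.
    public void CopyWeightsFromDefault()
    {
        for (int d = 1; d < Weights.Length; d++)
            Array.Copy(Weights[0], Weights[d], Weights[0].Length);
    }
}
```

Updating only on the default device and then broadcasting keeps a single source of truth for the weights, at the cost of one gather and one scatter per step.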

Method Details

BaseSeq2SeqFramework() public method

public BaseSeq2SeqFramework ( Array deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f, Array compilerOptions = null ) : AdvUtils
deviceIds Array
processorType ProcessorTypeEnums
modelFilePath string
memoryUsageRatio float
compilerOptions Array
Returns AdvUtils

CreateComputGraph() public method

public CreateComputGraph ( int deviceIdIdx, bool needBack = true ) : IComputeGraph
deviceIdIdx int
needBack bool
Returns IComputeGraph
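The `needBack` flag presumably controls whether the graph records what backward propagation needs (training) or skips that bookkeeping (validation or decoding); that reading is an assumption, as is treating `deviceIdIdx` as an index into the constructor's `deviceIds`, which this sketch does not model. A minimal tape-based graph illustrating the idea follows; `TinyGraph` and `Scalar` are hypothetical and much simpler than IComputeGraph.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical scalar value with a gradient slot.
public sealed class Scalar
{
    public float Value;
    public float Grad;
    public Scalar(float value) { Value = value; }
}

// Hypothetical minimal computing graph. When needBack is true, every operation
// records a closure that propagates gradients; Backward() replays them in reverse.
// When needBack is false, nothing is recorded, saving time and memory.
public sealed class TinyGraph
{
    private readonly bool needBack;
    private readonly List<Action> backwardOps = new List<Action>();

    public TinyGraph(bool needBack) { this.needBack = needBack; }

    public int RecordedOps => backwardOps.Count;

    public Scalar Mul(Scalar a, Scalar b)
    {
        var result = new Scalar(a.Value * b.Value);
        if (needBack)
            backwardOps.Add(() =>
            {
                a.Grad += b.Value * result.Grad;   // d(a*b)/da = b
                b.Grad += a.Value * result.Grad;   // d(a*b)/db = a
            });
        return result;
    }

    public Scalar Add(Scalar a, Scalar b)
    {
        var result = new Scalar(a.Value + b.Value);
        if (needBack)
            backwardOps.Add(() =>
            {
                a.Grad += result.Grad;
                b.Grad += result.Grad;
            });
        return result;
    }

    // Seed d(output)/d(output) = 1, then run the recorded closures in reverse order.
    public void Backward(Scalar output)
    {
        if (!needBack) throw new InvalidOperationException("Graph was created with needBack = false.");
        output.Grad = 1f;
        for (int i = backwardOps.Count - 1; i >= 0; i--) backwardOps[i]();
    }
}
```

For y = x * w + b with x = 3, w = 2, b = 1, a graph built with `needBack: true` yields y = 7 and gradients 3, 2, and 1 for w, x, and b; a graph built with `needBack: false` records no operations at all.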

LoadModel() public method

Loads the model from the given file.
public LoadModel ( Func InitializeParameters ) : IModelMetaData
InitializeParameters Func
Returns IModelMetaData

RunNetwork() public method

public RunNetwork ( Func ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor ) : AdvUtils
ForwardOnSingleDevice Func
sntPairBatchs List
batchSplitFactor int
Returns AdvUtils
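The following sketch shows one plausible reading of RunNetwork's parameters: each device runs its own batch in parallel, and `batchSplitFactor` is assumed to cut that batch into that many smaller chunks processed one after another, which lowers peak memory at some cost in speed. That interpretation is an assumption, not confirmed by this page. `BatchRunner` and the delegate shape `(deviceIdx, chunk) -> cost` are hypothetical; the real `ForwardOnSingleDevice` delegate's generic arguments are not shown above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class BatchRunner
{
    // Hypothetical sketch: run one batch per device in parallel. Each device's batch
    // is split into batchSplitFactor chunks that are processed sequentially on that
    // device. Returns the cost summed over all devices and chunks.
    public static float RunNetwork<T>(
        Func<int, IReadOnlyList<T>, float> forwardOnSingleDevice, // (deviceIdx, chunk) -> cost
        IReadOnlyList<IReadOnlyList<T>> batchPerDevice,
        int batchSplitFactor)
    {
        if (batchSplitFactor < 1) throw new ArgumentOutOfRangeException(nameof(batchSplitFactor));

        var costs = new float[batchPerDevice.Count];   // one slot per device, no sharing
        Parallel.For(0, batchPerDevice.Count, deviceIdx =>
        {
            var batch = batchPerDevice[deviceIdx];
            int chunkSize = (batch.Count + batchSplitFactor - 1) / batchSplitFactor; // ceiling division
            for (int start = 0; start < batch.Count; start += chunkSize)
            {
                var chunk = batch.Skip(start).Take(chunkSize).ToList();
                costs[deviceIdx] += forwardOnSingleDevice(deviceIdx, chunk);
            }
        });
        return costs.Sum();
    }
}
```

With two devices holding 4 and 3 items and `batchSplitFactor = 2`, the forward delegate is called four times, with chunks of 2 and 2 on the first device and 2 and 1 on the second.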

SaveModel() public method

public SaveModel ( IModelMetaData modelMetaData ) : bool
modelMetaData IModelMetaData
Returns bool
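SaveModel and LoadModel, together with the private SaveParameters(Stream) and LoadParameters(Stream), indicate that weights are written to and restored from a stream. Seq2SeqSharp's actual file format is not described on this page, so the following is a hypothetical round trip using `BinaryWriter` and `BinaryReader`; `ParameterStore` and its layout are invented for illustration. Loading fills parameters that are already registered, which is one reason a loader might first call an initialization callback such as InitializeParameters to build the network.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

// Hypothetical parameter registry: named weight arrays are written to and read
// back from a stream in a fixed (ordinal, sorted) order.
public sealed class ParameterStore
{
    private readonly SortedDictionary<string, float[]> parameters =
        new SortedDictionary<string, float[]>(StringComparer.Ordinal);

    public void Register(string name, float[] weights) => parameters[name] = weights;

    public float[] Get(string name) => parameters[name];

    // Layout: count, then for each parameter its name, length, and values.
    public void SaveParameters(Stream stream)
    {
        using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
        {
            writer.Write(parameters.Count);
            foreach (var kv in parameters)
            {
                writer.Write(kv.Key);
                writer.Write(kv.Value.Length);
                foreach (float w in kv.Value) writer.Write(w);
            }
        }
    }

    // Reads into already-registered arrays, so the network must be built first.
    public void LoadParameters(Stream stream)
    {
        using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
        {
            int count = reader.ReadInt32();
            for (int p = 0; p < count; p++)
            {
                string name = reader.ReadString();
                int length = reader.ReadInt32();
                float[] target = parameters[name];   // throws if the name is unknown
                if (target.Length != length)
                    throw new InvalidDataException($"Shape mismatch for '{name}'.");
                for (int i = 0; i < length; i++) target[i] = reader.ReadSingle();
            }
        }
    }
}
```

Saving to a `MemoryStream`, rewinding it, and loading into a second store with the same registered names and shapes reproduces the original weights exactly.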