Tengyu Ma (Facebook AI Research)
Based on joint work with Yuanzhi Li (Princeton) and Hongyang Zhang (Stanford)
• Over-parameterization: # parameters ≫ # examples
• There exist sets of parameters that can
  • fit the training data and generalize to test data
  • or fit real inputs with random labels, and fail to generalize
  • or fit the training data but fail to generalize
This talk: analysis for simpler models that share the properties above (matrix sensing, and neural nets with quadratic activations)
• Uniform convergence doesn't hold
  • "training loss ≈ test loss for all parameters" fails: there exist models that fit the training data but fail to generalize (test loss ≫ training loss)
• Algorithm matters: multiple local/global minima exist, and the algorithm chooses the one that generalizes
  • different algorithms converge to local minima of the training loss, but generalize differently [Keskar et al. '16, Wilson et al. '17, Dinh et al. '17]
• Post-mortem explanations: margin theory, PAC-Bayes, and compression-based bounds [Bartlett et al. '17, Neyshabur et al. '17, Arora et al. '18, Dziugaite and Roy '18]
Algorithms matter: stochastic gradient descent, with proper initialization and learning rate, prefers an optimal solution with low complexity when one exists
• # parameters is almost irrelevant; the intrinsic complexity of the data matters
• This talk: a rigorous argument for matrix sensing and quadratic neural networks
Ø! data " #,, " & R ), *! Ø Claim: minimize,(.) with gradient descent starting with. = 0 is equivalent to solve Ø Gradient descent is limited to search in a subspace Ø Related: GD on logistic loss converges to max margin solution [Soudry et al. 17, Ji&Telgarsky 17] span(" #,, " & )
• n data points: measurement matrices a_1, …, a_n ∈ R^{d×d} with entries from the standard normal distribution
• Unknown PSD matrix M* ∈ R^{d×d} of rank r
• We observe y_i = ⟨a_i, M*⟩
• Variable U ∈ R^{d×d}:
  min_U f(U) = Σ_{i=1}^n (y_i − ⟨a_i, U U^⊤⟩)²
• Focus: gradient descent U_{t+1} = U_t − η ∇f(U_t)
• Well-studied problem with efficient solutions [Recht et al. '10, Candès et al. '07, Tu et al. '15, Zheng and Lafferty '15]
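A minimal sketch of this setup and update (illustrative only: the dimensions, sample size, initialization scale, and learning rate below are placeholders, not values from the slides):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, n = 50, 3, 1000                   # illustrative sizes (n << d^2 parameters)
A = rng.standard_normal((n, d, d))      # Gaussian measurement matrices a_i
B = rng.standard_normal((d, r)) / np.sqrt(d)
M_star = B @ B.T                        # unknown rank-r PSD matrix M*
y = np.einsum('nij,ij->n', A, M_star)   # observations y_i = <a_i, M*>

def grad_f(U):
    # f(U) = sum_i (y_i - <a_i, U U^T>)^2
    residual = y - np.einsum('nij,ij->n', A, U @ U.T)
    G = -2 * np.einsum('n,nij->ij', residual, A)
    return (G + G.T) @ U                # d<a, U U^T>/dU = (a + a^T) U

alpha, eta = 1e-3, 0.05 / n             # small initialization, illustrative step size
U = alpha * np.eye(d)                   # U_0 = alpha * I
for _ in range(500):
    U = U - eta * grad_f(U)             # U_{t+1} = U_t - eta * grad f(U_t)

print(np.linalg.norm(U @ U.T - M_star, 'fro') / np.linalg.norm(M_star, 'fro'))
```

With a small initialization scale, the iterate U U^⊤ typically ends up close to M* despite U being a full d×d matrix, in line with the empirical observation on the next slides.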
min U f(u) = 3 4 : = 7 4, * nx (y i ha i,uu > i) 2 i=1 Ø Regime of parameters:! #$ % # % Ø Ideal solution: ' satisfying '' ( = * has zero training error Ø other solution ' with zero training error but '' ( * Gradient descent with small initialization empirically converges to the ideal solution! [Gunasekar et al. 2017] Ø Compared to low-rank factorization (taking ' R / 1 ): the algorithm finds the correct rank automatically
test error (population risk): E f(U) ∝ ‖M* − U U^⊤‖_F²
[Figure: test error vs. iteration; experiment with r = 5, n = 5dr]
• Early stopping and stochasticity are not necessary
• Systematic empirical studies in [Gunasekar et al. 2017]
Theorem [Li-M.-Zhang '17]: With Õ(dr²) observations, initialization U_0 = α I (for a sufficiently small α > 0) and learning rate η, at every iteration t satisfying (1/η)·log(d/α) ≲ t ≲ (1/η)·poly(1/α), the generalization error is bounded by
  ‖U_t U_t^⊤ − M*‖_F ≲ α √d
Technicalities:
• We assume M* is well-conditioned
• The theorem also holds when the measurements a_1, …, a_n satisfy the δ-restricted isometry property with δ ≲ 1/√r
• The runtime bound is non-trivial even with infinite samples
Gradient descent prefers low-complexity solutions
S_r = {approximately rank-r solutions} := {U : σ_{r+1}(U) ≤ ε} for a small ε
[Figure: the generalizable global minima of the training loss lie inside S_r, the non-generalizable global minima lie outside, and GD starts at 0 inside S_r]
More concrete analysis plan:
• GD on the population risk E f stays in S_r
• GD on f behaves similarly to GD on E f inside S_r: ∇E f(U) ≈ ∇f(U)
• Generalization is trivial in S_r: E f(U) ≈ f(U) for all U ∈ S_r
[Figure: as before, generalizable vs. non-generalizable global minima of the training loss; the GD-on-E f trajectory stays inside S_r starting from 0]
• Input dim = 100
• Generate labels with a network of hidden layer size r = 1
• Train with hidden layer size = 100
• WLOG, assume M* = u* u*^⊤ with ‖u*‖ = 1
• Decompose the iterate U_t into signal and noise:
  U_t = u* r_t^⊤ + E_t,  where r_t = U_t^⊤ u*  and  E_t = (I − u* u*^⊤) U_t
Goals: show inductively (via one-step bounds of the form ‖E_{t+1}‖ ≤ ‖E_t‖ + small term and ‖r_{t+1}‖ ≥ (1 + η(1 − ‖r_t‖²))‖r_t‖ − small term; see the sketch below and Lemma 1) that
• ‖E_t‖ stays ≈ 0
• ‖r_t‖ → 1
• These imply U_t U_t^⊤ → u* u*^⊤
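A small sketch of these goals for GD on the population risk in the rank-1 case (the dimension, initialization scale, and step size are illustrative, not from the slides): ‖r_t‖ grows toward 1 while ‖E_t‖ stays at the initialization scale α.

```python
import numpy as np

# Illustrative rank-1 sketch: GD on the population risk ||U U^T - u* u*^T||_F^2,
# tracking the signal/noise decomposition
#   r_t = U_t^T u*   and   E_t = (I - u* u*^T) U_t.
rng = np.random.default_rng(0)
d, alpha, eta = 50, 1e-3, 0.1
u_star = rng.standard_normal(d)
u_star /= np.linalg.norm(u_star)
M = np.outer(u_star, u_star)

U = alpha * np.eye(d)
for t in range(301):
    if t % 50 == 0:
        r_norm = np.linalg.norm(U.T @ u_star)            # ||r_t||, grows toward 1
        E_norm = np.linalg.norm((np.eye(d) - M) @ U, 2)  # ||E_t||, stays ~alpha
        print(t, r_norm, E_norm)
    U = U - eta * (U @ U.T - M) @ U    # population gradient step (up to constants)
```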
Lemma 1: ‖E_{t+1}‖ ≤ ‖E_t‖ + 2ηδ
• Preparation: on the population risk, the GD step is
  U_{t+1} = U_t − η(U_t U_t^⊤ − M)U_t = (I − η(U_t U_t^⊤ − M))U_t
  (the deviation of the empirical gradient from this is small when U_t is approximately low rank)
• Proof: with E_t = (I − u* u*^⊤)U_t,
  E_{t+1} = E_t (I − η U_t^⊤ U_t) + small term, because (I − u* u*^⊤) M = 0
• GD on the population risk reduces the error: ‖E_{t+1}‖ ≤ ‖E_t‖ + small term
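Filling in the one-step algebra behind the proof bullet (this is the population-risk computation; on the empirical risk the gradient deviation contributes the extra small term):

```latex
\begin{align*}
E_{t+1} &= (I - u^* u^{*\top})\, U_{t+1}
         = (I - u^* u^{*\top})\bigl(I - \eta (U_t U_t^\top - M)\bigr) U_t \\
        &= E_t - \eta\, (I - u^* u^{*\top})\, U_t U_t^\top U_t
              + \eta\, \underbrace{(I - u^* u^{*\top})\, M}_{=\,0}\, U_t \\
        &= E_t - \eta\, E_t\, U_t^\top U_t
         = E_t \bigl(I - \eta\, U_t^\top U_t\bigr),
\end{align*}
```

so ‖E_{t+1}‖ ≤ ‖E_t‖ whenever η‖U_t‖² ≤ 1, since U_t^⊤ U_t ⪰ 0.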
• q: entry-wise quadratic activation; the network computes y = 1^⊤ q(U^⊤ x) = ‖U^⊤ x‖²
• Almost equivalent to matrix sensing with rank-1 measurements:
  y = ⟨x x^⊤, U U^⊤⟩, with measurement matrix x x^⊤ and matrix to recover M* = U* U*^⊤
• Only difference: unlike random Gaussian measurements, x x^⊤ doesn't satisfy the restricted isometry property
• Solution: (adaptively) throw away a very small fraction of the data that would break the restricted isometry property
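A short check of this equivalence (sizes and variable names are illustrative):

```python
import numpy as np

# Illustrative check: a one-hidden-layer network with entry-wise quadratic
# activation computes a rank-1 "measurement" of U U^T:
#   y(x) = 1^T q(U^T x) = sum_j (u_j^T x)^2 = <x x^T, U U^T>.
rng = np.random.default_rng(0)
d, k = 20, 5
U = rng.standard_normal((d, k))   # hidden-layer weights (columns u_j)
x = rng.standard_normal(d)

y_network = np.sum((U.T @ x) ** 2)              # forward pass, quadratic activation
y_sensing = np.sum(np.outer(x, x) * (U @ U.T))  # <x x^T, U U^T>
print(np.isclose(y_network, y_sensing))         # True (up to floating point)
```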
• Generalization error depends on initialization
[Figure: test error curves for initialization U_0 = α I with different values of α]
[Figure, continued: initialization scale α]
• Caveat: SGD or GD with a large initialization can work for quadratic neural networks. (But the current theory requires small initialization.)
• Algorithm analyzed: GD on
  min_U f(U) = Σ_{i=1}^n (y_i − ⟨a_i, U U^⊤⟩)²
• Algorithm for comparison: projected GD on
  min_{Z ⪰ 0} g(Z) = Σ_{i=1}^n (y_i − ⟨a_i, Z⟩)²
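A sketch of the two updates being compared (the dimensions, data, step sizes, and the eigenvalue-clipping projection routine are illustrative choices, not details from the slides):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, n = 30, 2, 300
A = rng.standard_normal((n, d, d))
B = rng.standard_normal((d, r)) / np.sqrt(d)
M_star = B @ B.T
y = np.einsum('nij,ij->n', A, M_star)

def residuals(M):
    return y - np.einsum('nij,ij->n', A, M)

# (1) Analyzed: GD on the factored, over-parameterized objective f(U).
U, eta_u = 1e-3 * np.eye(d), 0.05 / n
for _ in range(500):
    G = -2 * np.einsum('n,nij->ij', residuals(U @ U.T), A)
    U = U - eta_u * (G + G.T) @ U

# (2) Comparison: projected GD on the convex objective g(Z) over the PSD cone.
def project_psd(Z):
    Z = (Z + Z.T) / 2
    w, V = np.linalg.eigh(Z)
    return (V * np.clip(w, 0, None)) @ V.T      # zero out negative eigenvalues

Z, eta_z = np.zeros((d, d)), 0.05 / n
for _ in range(500):
    G = -2 * np.einsum('n,nij->ij', residuals(Z), A)
    Z = project_psd(Z - eta_z * G)

rel_err = lambda M: np.linalg.norm(M - M_star, 'fro') / np.linalg.norm(M_star, 'fro')
print('factored GD:', rel_err(U @ U.T), '  projected GD:', rel_err(Z))
```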
• Algorithms have an implicit regularization effect
Open questions:
• other matrix-factorization-based models
• logistic loss [Gunasekar et al. '18]
• neural nets with other activation functions and losses (more in Nati's talk)
• better understanding of algorithms for deep learning, which seems to be very helpful for fully understanding generalization
Thank you!