* ebayes2_reml y1 y2 p11 p22 p12 [vars] [if ...], 
* [ noconstant
*   gen1() gen2()
*   theta1() theta2()
*   s11() s22() s12()
*   us11() us22() us12()
*   raw11() raw22() raw12()
* ]

* y1, y2 - the two estimated fixed effects (fe) for each unit
* p11, p22, p12 - the estimated measurement-error variances of the first and
*  second fe, and the covariance between them
* vars - RHS variables in the regression used to estimate the true mean
*        of the fe (theta). May be left blank.
* noconstant - do not add a constant to the RHS variables when estimating the
*              underlying process
* gen1(), gen2() - new variables to hold the empirical-Bayes-adjusted fe
* theta1(), theta2() - new variables to hold the value toward which each fe is
*                      shrunk (e.g. mu)
* s11(), s22(), s12() - new variables to hold the underlying vcov matrix of the
*                       fe data-generating process (conditional on covariates)
* us11(), us22(), us12() - new variables to hold the underlying vcov matrix
*                          (unconditional on the covariates)
* raw11(), raw22(), raw12() - new variables to hold the raw vcov matrix of the
*                             fe, including measurement error
* Each option variable must not already exist; options left out are computed
* into temporary variables and discarded.

program define ebayes2_reml, sortpreserve

	syntax varlist(numeric min=5) [if], ///
		[ NOconstant ///
		gen1(name) gen2(name) ///
		theta1(name) theta2(name) ///
		s11(name) s22(name) s12(name) ///
		us11(name) us22(name) us12(name) ///
		raw11(name) raw22(name) raw12(name) ///
		]
	
	* estimation sample: [if] and nonmissing on every variable in varlist
	marksample touse
		
	* varlist = y1 y2 p11 p22 p12 [RHS vars]
	gettoken y1 varlist: varlist
	gettoken y2 varlist: varlist
	gettoken p11 varlist: varlist
	gettoken p22 varlist: varlist
	gettoken p12 coeff: varlist
	
	* output variables: options left unspecified get a tempvar so the mata code
	* can write to every one of them; specified names must be new.
	* names are validated here but the variables are only created after all
	* other checks, so a failed check does not leave stray variables behind.
	local outvars gen1 gen2 theta1 theta2 s11 s22 s12 us11 us22 us12 raw11 raw22 raw12
	foreach var of local outvars {
		if ("``var''"=="") {
			tempvar `var'
		}
		else {
			* -confirm new variable- requires an exact unused name; plain
			* -confirm variable- also matches abbreviations of existing
			* variables and would wrongly reject such a name
			capture confirm new variable ``var''
			if (_rc) {
				display as error "variable ``var'' must not exist in data yet!"
				error 198
			}
		}
	}
	
	* add constant term unless requested not to
	if ("`noconstant'"!="noconstant") {
		tempvar ones
		qui gen byte `ones' = 1
		local coeff "`coeff' `ones'"
	}
	if (`: word count `coeff''==0) {
		display as error "no RHS variables: specify vars or omit noconstant"
		error 102
	}
	
	* reduce sample to the set of y vars, x vars, measurement error vcovs that are nonmissing
	markout `touse' `y1' `y2' `p11' `p22' `p12' `coeff'
	
	* deal with collinearity in the variables
	forvalues yvar=1/2 {
		capture _rmdcoll `y`yvar'' `coeff' if `touse', nocons
		if (_rc==459) {
			display as error ///
				"y`yvar' variable perfectly collinear with at least 1 RHS variable. exiting"
			error 198
		}
		else if (_rc) {
			* any other failure (e.g. no observations): r() is not usable
			error _rc
		}
		if (r(k_omitted) > 0) {
			display as error "collinearity in RHS variables. exiting"
			error 198
		}
	}
	
	* starting values for the sigma matrix come from a simple unweighted multivariate
	* regression less the simple average (over the estimation sample) of the
	* measurement errors. a tempname avoids overwriting a user matrix.
	tempname SIG
	qui mvreg `y1' `y2' = `coeff' if `touse', nocons
	matrix `SIG' = e(Sigma)
	display "sigma matrix including meas error"
	matrix list `SIG', noheader
	
	qui summ `p11' if `touse'
	local s11_start = max(`SIG'[1,1] - r(mean),0)
	qui summ `p22' if `touse'
	local s22_start = max(`SIG'[2,2] - r(mean),0)
	qui summ `p12' if `touse'
	* clamp the covariance so the starting correlation lies in [-1, 1]
	local s12_start = max(min(`SIG'[2,1] - r(mean),sqrt(`s11_start'*`s22_start')),-sqrt(`s11_start'*`s22_start'))
	
	display "starting values:"
	display "s11 `s11_start' s22 `s22_start' s12 `s12_start'"
	
	* create the output variables (double precision) for mata to fill in
	foreach var of local outvars {
		qui gen double ``var'' = .
	}
	
	mata: reml( ///
		"`y1'","`y2'","`coeff'","`coeff'", ///
		"`p11'","`p22'","`p12'", ///
		"`s11'","`s22'","`s12'", ///
		"`us11'","`us22'","`us12'", ///
		"`raw11'","`raw22'","`raw12'", ///
		`s11_start',`s22_start',`s12_start', ///
		"`theta1'","`theta2'","`gen1'","`gen2'", ///
		"`touse'")

end


mata:

// reml(): REML estimation of a bivariate empirical-Bayes model for pairs of
// noisy fixed effects, followed by EB shrinkage and summary statistics.
//
// Inputs (Stata variable names; only observations with tousename != 0):
//   y1, y2            the pair of estimated fixed effects
//   x1, x2            space-separated RHS variable lists for y1 and y2
//   p11, p22, p12     measurement-error variances / covariance of (y1, y2)
//   s11_start, s22_start, s12_start  optimizer starting values for S
// Outputs, written into already-existing Stata variables (same observations):
//   theta1, theta2    X*B, the GLS fitted mean of each fe
//   gen1, gen2        EB-adjusted fe: theta + S (S+P_i)^-1 (y_i - theta_i)
//   s11, s22, s12     the estimated underlying vcov S (same value in every row)
//   us11, us22, us12  S plus the implied vcov of X*B (alternative weights)
//   raw11, raw22, raw12  weighted vcov of the observed fe (alternative weights)
// Summaries under both weighting schemes are also printed.
void reml(
	string scalar y1, string scalar y2,
	string scalar x1, string scalar x2,
	string scalar p11, string scalar p22, string scalar p12,
	string scalar s11, string scalar s22, string scalar s12,
	string scalar us11, string scalar us22, string scalar us12,
	string scalar raw11, string scalar raw22, string scalar raw12,
	real s11_start, real s22_start, real s12_start,
	string scalar theta1, string scalar theta2,
	string scalar gen1, string scalar gen2,
	string scalar tousename
) {
	// bring in the variables from stata in "wide format" (one observation for each
	// pair), e.g. [ Y1_i Y2_i ] = [ X1_i X2_i ]
	
	// the y variables
	y_wide = .
	st_view(y_wide, . , (y1,y2), tousename)
	// the x variables for y1 and for y2
	x_left = .
	st_view(x_left, . , tokens(x1), tousename)
	x_right = .
	st_view(x_right, . , tokens(x2), tousename)
	// the measurement error (co)variances, one row (p11, p22, p12) per pair
	p_wide = .
	st_view(p_wide, . , (p11,p22,p12), tousename)
	// outputs: underlying means and eb adjusted values
	theta_wide = .
	st_view(theta_wide, . , (theta1,theta2), tousename)
	gen_wide = .
	st_view(gen_wide, . , (gen1,gen2), tousename)
	
	n_pairs = rows(y_wide)
	c_left = cols(x_left)
	c_right = cols(x_right)
	
	// convert to "long format": pair i occupies rows 2i-1 (y1) and 2i (y2)
	// e.g. [Y1_i \ Y2_i ] = [X1_i 0 \ 0 X2_i ]
	Y = J(2*n_pairs,1,.)
	X = J(2*n_pairs,c_left+c_right,0)
	for (i=1; i<=n_pairs; i++) {
		j = 2*i - 1
		Y[j,1] = y_wide[i,1]
		Y[j+1,1] = y_wide[i,2]
		X[|j,1 \ j,c_left|] = x_left[i,.]
		X[|j+1,c_left+1 \ j+1,c_left+c_right|] = x_right[i,.]
	}
	
	// receives the GLS coefficients (B~) from reml_evalf()
	B = J(cols(X),1,.)
	
	// maximize the REML likelihood over s = (s11, s22, s12)
	evaluator = optimize_init()
	optimize_init_evaluator(evaluator,&reml_evalf())
	// reml_evalf() returns the scalar log likelihood
	optimize_init_evaluatortype(evaluator,"d0")
	optimize_init_params(evaluator,(s11_start,s22_start,s12_start))
	// P is passed in wide form (one row per pair); reml_evalf() builds each
	// 2x2 block itself, avoiding a 2N x 2N block-diagonal matrix (O(N^2) memory)
	optimize_init_argument(evaluator,1,Y)
	optimize_init_argument(evaluator,2,X)
	optimize_init_argument(evaluator,3,p_wide)
	optimize_init_argument(evaluator,4,B)
	s = optimize(evaluator)
	
	// reml_evalf() overwrites B on every call, and optimize()'s last call
	// (e.g. for numerical derivatives) is not guaranteed to be at s, so
	// evaluate once more at the optimum to obtain the matching B
	lnf_at_s = .
	reml_evalf(0, s, Y, X, p_wide, B, lnf_at_s, ., .)
	
	// underlying mean of the Y
	theta = X*B
	
	// the underlying vcov in block form
	s_block = s[1,1], s[1,3] \ s[1,3], s[1,2]
	
	// now we work in "block form", i.e. at the pair level, and summarize the
	// vcovs in two ways:
	// "optimal": matrix weights (S+P_i)^-1. yet this seems to not necessarily
	// return positive semidefinite vcov matrices!
	// "alternative": one scalar weight per pair, 1/sqrt((s11+p11_i)*(s22+p22_i)),
	// i.e. the geometric mean of the two univariate weights 1/(s_kk+p_kk_i),
	// applied to every element, so the weighted vcov of y is an ordinary
	// weighted covariance matrix (an element-wise variant is left commented out)
	
	// optimal weights
	// average of (y1 \ y2)
	y_bar_opt = J(2,1,0)
	// average of (e1 \ e2) (where e = y - theta)
	e_bar_opt = J(2,1,0)
	// variance covariance matrix for (y1 \ y2)
	y_vcov_opt = J(2,2,0)
	// variance covariance matrix for (e1 \ e2 )
	// this embodies both s and p, and is some weighted average of it
	// though the degrees of freedom correction is probably off
	// (this is computed as a sanity check, and is not used)
	e_vcov_opt = J(2,2,0)
	// average of the underlying vcov + measurement error vcov
	// computed using the same weights that we use to make y_vcov
	sp_bar_opt = J(2,2,0)
	// rolling sum of the weights (so we can premultiply it off)
	sum_wt_opt = J(2,2,0)
	
	// alternatively weighted averages and variances
	y_bar_alt = J(2,1,0)
	e_bar_alt = J(2,1,0)
	y_vcov_alt = J(2,2,0)
	e_vcov_alt = J(2,2,0)
	sp_bar_alt = J(2,2,0)
	sum_wt_alt = J(2,2,0)
	
	// walk over each pair (block)
	for (i=1; i<=n_pairs; i++) {
		j = 2*i - 1
		
		// block of underlying means, observed FEs and deviations
		theta_block = theta[|j,1 \ j+1,1|]
		y_block = Y[|j,1 \ j+1,1|]
		e_block = y_block - theta_block
		
		// block of measurement error vcov for this pair
		p_block = p_wide[i,1],p_wide[i,3] \ p_wide[i,3],p_wide[i,2]
		
		// push the underlying means and the eb-adjusted fe's into their
		// stata variables
		theta_wide[i,.] = theta_block'
		gen_block = reml_posterior_mean(s_block, p_block, theta_block, y_block)
		gen_wide[i,.] = gen_block'
		
		// the weight for this pair when calculating the weighted average
		wt_opt = invsym(s_block + p_block)
		sum_wt_opt = sum_wt_opt + wt_opt
		
		// add weighted y and e to the average
		y_bar_opt = y_bar_opt + cross(wt_opt,y_block)
		e_bar_opt = e_bar_opt + cross(wt_opt,e_block)
		
		// add weighted yy' and ee' to the average
		y_vcov_opt = y_vcov_opt + cross(wt_opt,cross(y_block',y_block'))
		e_vcov_opt = e_vcov_opt + cross(wt_opt,cross(e_block',e_block'))
		
		// add weighted error vcov to the average
		sp_bar_opt = sp_bar_opt + cross(wt_opt,(s_block+p_block))
		
		// alternative weighting!
		
		// // element-wise variant (not used):
		// wt_alt = diag(diagonal(s_block+p_block)):^-1
		// wt_alt[1,2] = sqrt(wt_alt[1,1]*wt_alt[2,2])
		// wt_alt[2,1] = wt_alt[1,2]
		
		wt_alt = J(2,2, 1/( sqrt( (s_block[1,1]+p_block[1,1] )*(s_block[2,2]+p_block[2,2]) )))
		sum_wt_alt = sum_wt_alt + wt_alt
		
		// add weighted y and e to the average
		y_bar_alt = y_bar_alt + diagonal(wt_alt):*y_block
		e_bar_alt = e_bar_alt + diagonal(wt_alt):*e_block
		
		// add weighted yy' and ee' to the average (elementwise weight)
		y_vcov_alt = y_vcov_alt + wt_alt:*cross(y_block',y_block')
		e_vcov_alt = e_vcov_alt + wt_alt:*cross(e_block',e_block')
		
		// add weighted error vcov to the average
		sp_bar_alt = sp_bar_alt + wt_alt:*(s_block+p_block)
	}
	
	// the inverted summed weights -- will normalize all weighted sums by this
	sum_wt_inv_opt = invsym(sum_wt_opt)
	
	// fix the averages
	y_bar_opt = cross(sum_wt_inv_opt,y_bar_opt)	
	e_bar_opt = cross(sum_wt_inv_opt,e_bar_opt)
	
	// fix the vcovs (and remove the first moment terms from them)
	y_vcov_opt = cross(sum_wt_inv_opt,y_vcov_opt)-cross(y_bar_opt',y_bar_opt')
	e_vcov_opt = cross(sum_wt_inv_opt,e_vcov_opt)-cross(e_bar_opt',e_bar_opt')
	// fix the averaged vcov
	sp_bar_opt = cross(sum_wt_inv_opt,sp_bar_opt)
	
	// force vcovs to be symmetric
	y_vcov_opt = (1/2)*(y_vcov_opt + y_vcov_opt')
	e_vcov_opt = (1/2)*(e_vcov_opt + e_vcov_opt')
	sp_bar_opt = (1/2)*(sp_bar_opt + sp_bar_opt')
	
	// implied vcov of theta (xb)
	// don't use vcov(xb) directly because there is a degrees of freedom correction
	// needed. instead, take vcov(y) - vcov(s+p)
	theta_vcov_opt = y_vcov_opt - sp_bar_opt
	
	// underlying vcov of the y's, unconditional on the x's
	// this equals vcov(xb) + underlying vcov
	us_vcov_opt = theta_vcov_opt + s_block
	
	// alternative weights!
	
	// the inverted summed weights -- will normalize all weighted sums by this
	sum_wt_inv_alt = sum_wt_alt:^-1
	
	// fix the averages
	y_bar_alt = diagonal(sum_wt_inv_alt):*y_bar_alt
	e_bar_alt = diagonal(sum_wt_inv_alt):*e_bar_alt
	
	// fix the vcovs (and remove the first moment terms from them)
	y_vcov_alt = sum_wt_inv_alt:*y_vcov_alt - cross(y_bar_alt',y_bar_alt')
	e_vcov_alt = sum_wt_inv_alt:*e_vcov_alt - cross(e_bar_alt',e_bar_alt')
	// fix the averaged vcov
	sp_bar_alt = sum_wt_inv_alt:*sp_bar_alt

	// implied vcov of theta (xb), as above
	theta_vcov_alt = y_vcov_alt - sp_bar_alt

	// underlying vcov of the y's, unconditional on the x's
	us_vcov_alt = theta_vcov_alt + s_block
		
	// copy the vcovs into the stata variables
	// we'll use the ALTERNATIVE weighted vcovs

	// the s variables (underlying vcov)
	s_wide = .
	st_view(s_wide, . , (s11,s22,s12), tousename)
	// the us variables (vcov unconditional on the X's)
	us_wide = .
	st_view(us_wide, . , (us11,us22,us12), tousename)
	// the raw variables (vcov of the y's)
	raw_wide = .
	st_view(raw_wide, . , (raw11,raw22,raw12), tousename)

	s_wide[.,.] = J(rows(s_wide),1, ( s_block[1,1], s_block[2,2], s_block[1,2] ) )
	us_wide[.,.] = J(rows(us_wide),1, ( us_vcov_alt[1,1], us_vcov_alt[2,2], us_vcov_alt[1,2] ) )
	raw_wide[.,.] = J(rows(raw_wide),1, ( y_vcov_alt[1,1], y_vcov_alt[2,2], y_vcov_alt[1,2] ) )
	
	printf("vcov of the underlying process\n")
	matlist(s_block)
	printf("underlying correlation\n")
	matlist( ( s_block[1,2]/sqrt(s_block[1,1]*s_block[2,2]) ) )
	
	printf("*** optimal weighting ***\n")
	printf("vcov of the x\n")
	matlist(theta_vcov_opt)
	printf("vcov of the underlying process (unconditional on the x)\n")
	matlist(us_vcov_opt)
	printf("underlying correlation (unconditional on the x)\n")
	matlist( ( us_vcov_opt[1,2]/sqrt(us_vcov_opt[1,1]*us_vcov_opt[2,2]) ) )
	printf("mean of the fe\n")
	matlist(y_bar_opt)
	printf("raw vcov of the fe\n")
	matlist(y_vcov_opt)
	printf("mean of the e\n")
	matlist(e_bar_opt)
	
	printf("*** alternative weighting ***\n")
	printf("vcov of the x\n")
	matlist(theta_vcov_alt)
	printf("vcov of the underlying process (unconditional on the x)\n")
	matlist(us_vcov_alt)
	printf("underlying correlation (unconditional on the x)\n")
	matlist( ( us_vcov_alt[1,2]/sqrt(us_vcov_alt[1,1]*us_vcov_alt[2,2]) ) )
	printf("mean of the fe\n")
	matlist(y_bar_alt)
	printf("raw vcov of the fe\n")
	matlist(y_vcov_alt)
	printf("mean of the e\n")
	matlist(e_bar_alt)
	
}

// reml_posterior_mean(): posterior mean of a bivariate normal mean with prior
// N(theta_block, s_block), after observing y_block with known measurement
// error vcov p_block (conjugate normal prior, known covariance).
// Computed as theta + S (S+P)^-1 (y - theta). This equals
// (S^-1 + P^-1)^-1 (S^-1 theta + P^-1 y) when S and P are invertible. It
// stays correct when S is singular, e.g. a zero underlying variance shrinks
// that component fully to theta.
real colvector reml_posterior_mean(
	real matrix s_block, real matrix p_block,
	real colvector theta_block, real colvector y_block
) {
	return(theta_block + s_block*invsym(s_block + p_block)*(y_block - theta_block))
}

// reml_evalf(): optimize() evaluator. Stores in lnf the restricted (REML) log
// likelihood, up to an additive constant, of Y = X*B + u with
// Var(u) = blockdiag(S + P_i) and S = [s11 s12 \ s12 s22], s = (s11, s22, s12).
// This is the REML given on page 32 of
// http://www.biostat.umn.edu/~xianghua/8452/note/03GLS.pdf
//   Y  2N x 1, pair i in rows 2i-1 and 2i
//   X  2N x k
//   P  either N x 3 with columns (p11, p22, p12) per pair, or the 2N x 2N
//      block-diagonal matrix with the 2x2 blocks P_i on its diagonal
//   B  overwritten with the GLS estimate (X'V^-1 X)^-1 X'V^-1 Y
// todo, g and H are not used.
function reml_evalf(todo,s,Y,X,P,B,lnf,g,H) {

	s_block = s[1,1], s[1,3] \ s[1,3], s[1,2]
	n_pairs = rows(Y)/2
	k = cols(X)
	// a dense block-diagonal P has 2N (an even number) columns, so 3 columns
	// unambiguously means the wide (one row per pair) layout
	wide_p = (cols(P)==3)
	
	// V = S+P is block diagonal, so X'V^-1 X, X'V^-1 Y and ln|V| are sums of
	// per-block terms; accumulate them directly instead of forming 2N x 2N
	// matrices, keeping each inverted block (stacked in sp_inv) for later
	XSPX = J(k,k,0)
	XSPY = J(k,1,0)
	sp_inv = J(2*n_pairs,2,.)
	logdet = 0
	for (i=1; i<=n_pairs; i++) {
		j = 2*i - 1
		if (wide_p) {
			p_block = P[i,1], P[i,3] \ P[i,3], P[i,2]
		}
		else {
			p_block = P[|j,j \ j+1,j+1|]
		}
		tmp_block = p_block + s_block
		tmp_inv = invsym(tmp_block)
		sp_inv[|j,1 \ j+1,2|] = tmp_inv
		x_block = X[|j,1 \ j+1,k|]
		XSPX = XSPX + cross(x_block, tmp_inv*x_block)
		XSPY = XSPY + cross(x_block, tmp_inv*Y[|j,1 \ j+1,1|])
		logdet = logdet + ln(det(tmp_block))
	}
	
	// estimated Beta = ( X'(SP^-1)X )^-1 * X'(SP^-1)Y
	XSPX_inv = invsym(XSPX)
	B = cross(XSPX_inv, XSPY)
	
	// the error deviations Y - X*B~ and the quadratic form Y_dev' V^-1 Y_dev
	Y_dev = Y - X*B
	quad = 0
	for (i=1; i<=n_pairs; i++) {
		j = 2*i - 1
		e_block = Y_dev[|j,1 \ j+1,1|]
		quad = quad + cross(e_block, sp_inv[|j,1 \ j+1,2|]*e_block)
	}
	
	// the REML term ln(det(XSPX)) is computed as the sum of the logged
	// eigenvalues, which still evaluates when det(XSPX) underflows to 0
	sumlogeig = sum(ln(symeigenvalues(XSPX)))
	
	lnf = -(1/2)*logdet - (1/2)*sumlogeig - (1/2)*quad
	
}

// matlist(): print a real matrix as a boxed SMCL table, with column numbers
// across the top and row numbers down the left side.
//   X    the matrix to display (nothing is printed when it is empty)
//   fmt  optional display format for the entries; "%g" when omitted/blank
void matlist(
    real matrix X,
    | string scalar fmt
    )
{
    real scalar     r, c, nr, nc, entw, cellw, roww, colw
    string scalar   cellfmt, indent, rulelen

    if (fmt=="") fmt = "%g"
    // width of a representative formatted entry
    entw = strlen(sprintf(fmt,-1/3))

    if (length(X)==0) return

    nr = rows(X)
    nc = cols(X)

    // digits needed for the row and column labels; each cell is wide enough
    // for either an entry or a column label, plus two spaces of padding
    roww  = trunc(log10(nr)) + 1
    colw  = trunc(log10(nc)) + 1
    cellw = max((colw,entw)) + 2
    cellfmt = "%" + strofreal(cellw) + "s"

    indent  = (2+roww+1)*" "
    rulelen = strofreal((cellw+1)*nc+1)

    // header line with the column numbers
    printf("{txt}" + indent + " ")
    for (c=1; c<=nc; c++) printf(cellfmt+" ", sprintf("%g", c))
    printf("  \n")

    // top border
    printf(indent + "{c TLC}{hline " + rulelen + "}{c TRC}\n")

    // one line per row: row number, then each formatted entry
    for (r=1; r<=nr; r++) {
        printf("{txt}  %"+strofreal(roww)+"s {c |}{res}", sprintf("%g", r))
        for (c=1; c<=nc; c++) printf(cellfmt+" ", sprintf(fmt, X[r,c]))
        printf(" {txt}{c |}\n")
    }

    // bottom border
    printf(indent + "{c BLC}{hline " + rulelen + "}{c BRC}\n")
}


end


exit
