1/* -*- Mode: C; tab-width: 4 -*-
2 *
3 * Copyright (c) 2002-2013 Apple Computer, Inc. All rights reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Secret.h"
19#include <stdarg.h>
20#include <stddef.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24#include <winsock2.h>
25#include <ws2tcpip.h>
26#include <windows.h>
27#include <process.h>
28#include <ntsecapi.h>
29#include <lm.h>
30#include "DebugServices.h"
31
32
33mDNSlocal OSStatus MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input );
34mDNSlocal OSStatus MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input );
35
36
37mDNSBool
38LsaGetSecret( const char * inDomain, char * outDomain, unsigned outDomainSize, char * outKey, unsigned outKeySize, char * outSecret, unsigned outSecretSize )
39{
40	PLSA_UNICODE_STRING		domainLSA;
41	PLSA_UNICODE_STRING		keyLSA;
42	PLSA_UNICODE_STRING		secretLSA;
43	size_t					i;
44	size_t					dlen;
45	LSA_OBJECT_ATTRIBUTES	attrs;
46	LSA_HANDLE				handle = NULL;
47	NTSTATUS				res;
48	OSStatus				err;
49
50	check( inDomain );
51	check( outDomain );
52	check( outKey );
53	check( outSecret );
54
55	// Initialize
56
57	domainLSA	= NULL;
58	keyLSA		= NULL;
59	secretLSA	= NULL;
60
61	// Make sure we have enough space to add trailing dot
62
63	dlen = strlen( inDomain );
64	err = strcpy_s( outDomain, outDomainSize - 2, inDomain );
65	require_noerr( err, exit );
66
67	// If there isn't a trailing dot, add one because the mDNSResponder
68	// presents names with the trailing dot.
69
70	if ( outDomain[ dlen - 1 ] != '.' )
71	{
72		outDomain[ dlen++ ] = '.';
73		outDomain[ dlen ] = '\0';
74	}
75
76	// Canonicalize name by converting to lower case (keychain and some name servers are case sensitive)
77
78	for ( i = 0; i < dlen; i++ )
79	{
80		outDomain[i] = (char) tolower( outDomain[i] );  // canonicalize -> lower case
81	}
82
83	// attrs are reserved, so initialize to zeroes.
84
85	ZeroMemory( &attrs, sizeof( attrs ) );
86
87	// Get a handle to the Policy object on the local system
88
89	res = LsaOpenPolicy( NULL, &attrs, POLICY_GET_PRIVATE_INFORMATION, &handle );
90	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
91	require_noerr( err, exit );
92
93	// Get the encrypted data
94
95	domainLSA = ( PLSA_UNICODE_STRING ) malloc( sizeof( LSA_UNICODE_STRING ) );
96	require_action( domainLSA != NULL, exit, err = mStatus_NoMemoryErr );
97	err = MakeLsaStringFromUTF8String( domainLSA, outDomain );
98	require_noerr( err, exit );
99
100	// Retrieve the key
101
102	res = LsaRetrievePrivateData( handle, domainLSA, &keyLSA );
103	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
104	require_noerr_quiet( err, exit );
105
106	// <rdar://problem/4192119> Lsa secrets use a flat naming space.  Therefore, we will prepend "$" to the keyname to
107	// make sure it doesn't conflict with a zone name.
108	// Strip off the "$" prefix.
109
110	err = MakeUTF8StringFromLsaString( outKey, outKeySize, keyLSA );
111	require_noerr( err, exit );
112	require_action( outKey[0] == '$', exit, err = kUnknownErr );
113	memcpy( outKey, outKey + 1, strlen( outKey ) );
114
115	// Retrieve the secret
116
117	res = LsaRetrievePrivateData( handle, keyLSA, &secretLSA );
118	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
119	require_noerr_quiet( err, exit );
120
121	// Convert the secret to UTF8 string
122
123	err = MakeUTF8StringFromLsaString( outSecret, outSecretSize, secretLSA );
124	require_noerr( err, exit );
125
126exit:
127
128	if ( domainLSA != NULL )
129	{
130		if ( domainLSA->Buffer != NULL )
131		{
132			free( domainLSA->Buffer );
133		}
134
135		free( domainLSA );
136	}
137
138	if ( keyLSA != NULL )
139	{
140		LsaFreeMemory( keyLSA );
141	}
142
143	if ( secretLSA != NULL )
144	{
145		LsaFreeMemory( secretLSA );
146	}
147
148	if ( handle )
149	{
150		LsaClose( handle );
151		handle = NULL;
152	}
153
154	return ( !err ) ? TRUE : FALSE;
155}
156
157
158mDNSBool
159LsaSetSecret( const char * inDomain, const char * inKey, const char * inSecret )
160{
161	size_t					inDomainLength;
162	size_t					inKeyLength;
163	char					domain[ 1024 ];
164	char					key[ 1024 ];
165	LSA_OBJECT_ATTRIBUTES	attrs;
166	LSA_HANDLE				handle = NULL;
167	NTSTATUS				res;
168	LSA_UNICODE_STRING		lucZoneName;
169	LSA_UNICODE_STRING		lucKeyName;
170	LSA_UNICODE_STRING		lucSecretName;
171	BOOL					ok = TRUE;
172	OSStatus				err;
173
174	require_action( inDomain != NULL, exit, ok = FALSE );
175	require_action( inKey != NULL, exit, ok = FALSE );
176	require_action( inSecret != NULL, exit, ok = FALSE );
177
178	// If there isn't a trailing dot, add one because the mDNSResponder
179	// presents names with the trailing dot.
180
181	ZeroMemory( domain, sizeof( domain ) );
182	inDomainLength = strlen( inDomain );
183	require_action( inDomainLength > 0, exit, ok = FALSE );
184	err = strcpy_s( domain, sizeof( domain ) - 2, inDomain );
185	require_action( !err, exit, ok = FALSE );
186
187	if ( domain[ inDomainLength - 1 ] != '.' )
188	{
189		domain[ inDomainLength++ ] = '.';
190		domain[ inDomainLength ] = '\0';
191	}
192
193	// <rdar://problem/4192119>
194	//
195	// Prepend "$" to the key name, so that there will
196	// be no conflict between the zone name and the key
197	// name
198
199	ZeroMemory( key, sizeof( key ) );
200	inKeyLength = strlen( inKey );
201	require_action( inKeyLength > 0 , exit, ok = FALSE );
202	key[ 0 ] = '$';
203	err = strcpy_s( key + 1, sizeof( key ) - 3, inKey );
204	require_action( !err, exit, ok = FALSE );
205	inKeyLength++;
206
207	if ( key[ inKeyLength - 1 ] != '.' )
208	{
209		key[ inKeyLength++ ] = '.';
210		key[ inKeyLength ] = '\0';
211	}
212
213	// attrs are reserved, so initialize to zeroes.
214
215	ZeroMemory( &attrs, sizeof( attrs ) );
216
217	// Get a handle to the Policy object on the local system
218
219	res = LsaOpenPolicy( NULL, &attrs, POLICY_ALL_ACCESS, &handle );
220	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
221	require_noerr( err, exit );
222
223	// Intializing PLSA_UNICODE_STRING structures
224
225	err = MakeLsaStringFromUTF8String( &lucZoneName, domain );
226	require_noerr( err, exit );
227
228	err = MakeLsaStringFromUTF8String( &lucKeyName, key );
229	require_noerr( err, exit );
230
231	err = MakeLsaStringFromUTF8String( &lucSecretName, inSecret );
232	require_noerr( err, exit );
233
234	// Store the private data.
235
236	res = LsaStorePrivateData( handle, &lucZoneName, &lucKeyName );
237	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
238	require_noerr( err, exit );
239
240	res = LsaStorePrivateData( handle, &lucKeyName, &lucSecretName );
241	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
242	require_noerr( err, exit );
243
244exit:
245
246	if ( handle )
247	{
248		LsaClose( handle );
249		handle = NULL;
250	}
251
252	return ok;
253}
254
255
256//===========================================================================================================================
257//	MakeLsaStringFromUTF8String
258//===========================================================================================================================
259
260mDNSlocal OSStatus
261MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input )
262{
263	int			size;
264	OSStatus	err;
265
266	check( input );
267	check( output );
268
269	output->Buffer = NULL;
270
271	size = MultiByteToWideChar( CP_UTF8, 0, input, -1, NULL, 0 );
272	err = translate_errno( size > 0, GetLastError(), kUnknownErr );
273	require_noerr( err, exit );
274
275	output->Length = (USHORT)( size * sizeof( wchar_t ) );
276	output->Buffer = (PWCHAR) malloc( output->Length );
277	require_action( output->Buffer, exit, err = mStatus_NoMemoryErr );
278	size = MultiByteToWideChar( CP_UTF8, 0, input, -1, output->Buffer, size );
279	err = translate_errno( size > 0, GetLastError(), kUnknownErr );
280	require_noerr( err, exit );
281
282	// We're going to subtrace one wchar_t from the size, because we didn't
283	// include it when we encoded the string
284
285	output->MaximumLength = output->Length;
286	output->Length		-= sizeof( wchar_t );
287
288exit:
289
290	if ( err && output->Buffer )
291	{
292		free( output->Buffer );
293		output->Buffer = NULL;
294	}
295
296	return( err );
297}
298
299
300
301//===========================================================================================================================
302//	MakeUTF8StringFromLsaString
303//===========================================================================================================================
304
305mDNSlocal OSStatus
306MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input )
307{
308	size_t		size;
309	OSStatus	err = kNoErr;
310
311	// The Length field of this structure holds the number of bytes,
312	// but WideCharToMultiByte expects the number of wchar_t's. So
313	// we divide by sizeof(wchar_t) to get the correct number.
314
315	size = (size_t) WideCharToMultiByte(CP_UTF8, 0, input->Buffer, ( input->Length / sizeof( wchar_t ) ), NULL, 0, NULL, NULL);
316	err = translate_errno( size != 0, GetLastError(), kUnknownErr );
317	require_noerr( err, exit );
318
319	// Ensure that we have enough space (Add one for trailing '\0')
320
321	require_action( ( size + 1 ) <= len, exit, err = mStatus_NoMemoryErr );
322
323	// Convert the string
324
325	size = (size_t) WideCharToMultiByte( CP_UTF8, 0, input->Buffer, ( input->Length / sizeof( wchar_t ) ), output, (int) size, NULL, NULL);
326	err = translate_errno( size != 0, GetLastError(), kUnknownErr );
327	require_noerr( err, exit );
328
329	// have to add the trailing 0 because WideCharToMultiByte doesn't do it,
330	// although it does return the correct size
331
332	output[size] = '\0';
333
334exit:
335
336	return err;
337}
338
339