00001
#include "factory.h"
00002
#include "integer.h"
00003
#include "filters.h"
00004
#include "hex.h"
00005
#include "randpool.h"
00006
#include "files.h"
00007
#include "trunhash.h"
00008
#include <iostream>
00009
#include <memory>
00010
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013
00014
// Declared here, defined elsewhere in the test program.
void RegisterFactories();	// presumably registers what ObjectFactoryRegistry<...>::CreateObject() looks up
RandomPool & GlobalRNG();	// generator passed to Validate() and SignerFilter in this file

//! One test record: field name -> field value, as filled in from GetField().
typedef std::map< std::string, std::string > TestData;
00018
00019 class TestFailure : public
Exception
00020 {
00021
public:
00022 TestFailure() : Exception(OTHER_ERROR,
"Validation test failed") {}
00023 };
00024
00025
static const TestData *s_currentTestData = NULL;
00026
00027
static void OutputTestData(
const TestData &v)
00028 {
00029
for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030 {
00031 cerr << i->first <<
": " << i->second << endl;
00032 }
00033 }
00034
00035
static void SignalTestFailure()
00036 {
00037 OutputTestData(*s_currentTestData);
00038
throw TestFailure();
00039 }
00040
00041
static void SignalTestError()
00042 {
00043 OutputTestData(*s_currentTestData);
00044
throw Exception(Exception::OTHER_ERROR,
"Unexpected error during validation test");
00045 }
00046
00047
class TestDataNameValuePairs :
public NameValuePairs
00048 {
00049
public:
00050 TestDataNameValuePairs(
const TestData &data) : m_data(data) {}
00051
00052
virtual bool GetVoidValue(
const char *name,
const std::type_info &valueType,
void *pValue)
const
00053
{
00054 TestData::const_iterator i = m_data.find(name);
00055
if (i == m_data.end())
00056
return false;
00057
00058
const std::string &value = i->second;
00059
00060
if (valueType ==
typeid(
int))
00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062
else if (valueType ==
typeid(
Integer))
00063 *reinterpret_cast<Integer *>(pValue) =
Integer((std::string(value) +
"h").c_str());
00064
else if (valueType ==
typeid(
ConstByteArrayParameter))
00065 {
00066 m_temp.resize(0);
00067
StringSource(value,
true,
new HexDecoder(
new StringSink(m_temp)));
00068 reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((
const byte *)m_temp.data(), m_temp.size(),
true);
00069 }
00070
else if (valueType ==
typeid(
const byte *))
00071 {
00072 m_temp.resize(0);
00073
StringSource(value,
true,
new HexDecoder(
new StringSink(m_temp)));
00074 *reinterpret_cast<const byte * *>(pValue) = (
const byte *)m_temp.data();
00075 }
00076
else
00077
throw ValueTypeMismatch(name,
typeid(std::string), valueType);
00078
00079
return true;
00080 }
00081
00082
private:
00083
const TestData &m_data;
00084
mutable std::string m_temp;
00085 };
00086
00087
const std::string & GetRequiredDatum(
const TestData &data,
const char *name)
00088 {
00089 TestData::const_iterator i = data.find(name);
00090
if (i == data.end())
00091 SignalTestError();
00092
return i->second;
00093 }
00094
00095
void PutDecodedDatumInto(
const TestData &data,
const char *name,
BufferedTransformation &target)
00096 {
00097 std::string s1 = GetRequiredDatum(data, name), s2;
00098
00099
int repeat = 1;
00100
if (s1[0] ==
'r')
00101 {
00102 repeat = atoi(s1.c_str()+1);
00103 s1 = s1.substr(s1.find(
' ')+1);
00104 }
00105
00106
if (s1[0] ==
'\"')
00107 s2 = s1.substr(1, s1.find(
'\"', 1)-1);
00108
else if (s1.substr(0, 2) ==
"0x")
00109
StringSource(s1.substr(2),
true,
new HexDecoder(
new StringSink(s2)));
00110
else
00111
StringSource(s1,
true,
new HexDecoder(
new StringSink(s2)));
00112
00113
while (repeat--)
00114 target.
Put((
const byte *)s2.data(), s2.size());
00115 }
00116
00117 std::string GetDecodedDatum(
const TestData &data,
const char *name)
00118 {
00119 std::string s;
00120 PutDecodedDatumInto(data, name,
StringSink(s).Ref());
00121
return s;
00122 }
00123
00124
// Validates both halves of a key pair at level 3 using GlobalRNG(), public
// key first.  The private key is not checked if the public key already
// failed.  No cross-check between the two keys is performed here.
// Any failure is reported via SignalTestFailure(), which throws.
void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
{
	const bool bothValid = pub.Validate(GlobalRNG(), 3) && priv.Validate(GlobalRNG(), 3);
	if (!bothValid)
		SignalTestFailure();
}
00139
00140
// Runs one signature-scheme test record.
//
// Required fields: Name, Test, KeyFormat.  With KeyFormat "DER" the keys come
// from PublicKey/PrivateKey; with "Component" they are assigned from the
// record's fields via TestDataNameValuePairs.
//
// "Verify", "NotVerify" and "PublicKeyValid" use only the public key.  Every
// other test kind also loads the private key into the signer.  Wrong results
// are reported via SignalTestFailure() and unsupported test kinds via
// SignalTestError(); both throw.
void TestSignatureScheme(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
	std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));

	TestDataNameValuePairs pairs(v);
	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");

	// The public key is always loaded into the verifier.
	if (keyFormat == "DER")
		verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	else if (keyFormat == "Component")
		verifier->AccessMaterial().AssignFrom(pairs);

	// ---- tests that need only the public key ----
	if (test == "Verify" || test == "NotVerify")
	{
		VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
		PutDecodedDatumInto(v, "Signature", verifierFilter);
		PutDecodedDatumInto(v, "Message", verifierFilter);
		verifierFilter.MessageEnd();
		// "Verify" expects a good signature, "NotVerify" a bad one.
		if (verifierFilter.GetLastResult() == (test == "NotVerify"))
			SignalTestFailure();
		return;
	}

	if (test == "PublicKeyValid")
	{
		if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
			SignalTestFailure();
		return;
	}

	// ---- tests that also need the private key ----
	if (keyFormat == "DER")
		signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
	else if (keyFormat == "Component")
		signer->AccessMaterial().AssignFrom(pairs);

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
	}
	else if (test == "Sign")
	{
		// Presumably writes the hex-encoded signature of Message to cout.
		// This test kind always reports failure afterwards, so it never passes.
		SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
		SignalTestFailure();
	}
	else
	{
		// "DeterministicSign", "RandomSign", "GenerateKey" and any unknown
		// test kind are not implemented.
		SignalTestError();
		assert(false);
	}
}
00212
00213
// Runs one asymmetric-cipher test record.
//
// Required fields: Name, Test, KeyFormat.  With KeyFormat "DER" the keys come
// from PrivateKey/PublicKey; with "Component" both are assigned from the
// record's fields.
// Test kinds:
//   "DecryptMatch"               - Ciphertext must decrypt to Plaintext
//   "KeyPairValidAndConsistent"  - both keys must validate
// Anything else is reported via SignalTestError() (throws).
void TestAsymmetricCipher(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
	std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));

	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");

	if (keyFormat == "DER")
	{
		decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
		encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	}
	else if (keyFormat == "Component")
	{
		TestDataNameValuePairs pairs(v);
		decryptor->AccessMaterial().AssignFrom(pairs);
		encryptor->AccessMaterial().AssignFrom(pairs);
	}

	if (test == "DecryptMatch")
	{
		std::string recovered;
		std::string expected = GetDecodedDatum(v, "Plaintext");
		StringSource ss(GetDecodedDatum(v, "Ciphertext"), true,
			new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(recovered)));
		if (recovered != expected)
			SignalTestFailure();
		return;
	}

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
		return;
	}

	SignalTestError();
	assert(false);
}
00252
00253
// Runs one symmetric-cipher test record.
//
// Required fields: Name, Test, Key, Ciphertext, Plaintext.  Other record
// fields are passed to SetKey() as parameters.  Data is processed with
// StreamTransformationFilter::NO_PADDING.
// Test kinds:
//   "Encrypt" - Plaintext must encrypt to Ciphertext
//   "Decrypt" - Ciphertext must decrypt to Plaintext
// Anything else is reported via SignalTestError() (throws).
void TestSymmetricCipher(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::string key = GetDecodedDatum(v, "Key");
	std::string ciphertext = GetDecodedDatum(v, "Ciphertext");
	std::string plaintext = GetDecodedDatum(v, "Plaintext");

	TestDataNameValuePairs pairs(v);

	if (test == "Encrypt")
	{
		std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
		encryptor->SetKey((const byte *)key.data(), key.size(), pairs);

		std::string output;
		StringSource ss(plaintext, true,
			new StreamTransformationFilter(*encryptor, new StringSink(output), StreamTransformationFilter::NO_PADDING));
		if (output != ciphertext)
			SignalTestFailure();
		return;
	}

	if (test == "Decrypt")
	{
		std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
		decryptor->SetKey((const byte *)key.data(), key.size(), pairs);

		std::string output;
		StringSource ss(ciphertext, true,
			new StreamTransformationFilter(*decryptor, new StringSink(output), StreamTransformationFilter::NO_PADDING));
		if (output != plaintext)
			SignalTestFailure();
		return;
	}

	SignalTestError();
	assert(false);
}
00288
00289
// Runs one message-digest (testDigest == true) or MAC (testDigest == false)
// test record.
//
// Required fields: Name, Test; a MAC also needs Key.
// Test kinds:
//   "Verify"          - Digest must match the digest of Message
//   "VerifyTruncated" - as Verify, with the digest truncated to TruncatedSize
//   "NotVerify"       - Digest must NOT match
// Anything else is reported via SignalTestError() (throws).
void TestDigestOrMAC(TestData &v, bool testDigest)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	member_ptr<MessageAuthenticationCode> mac;
	member_ptr<HashTransformation> hash;
	HashTransformation *hashAlg = NULL;	// whichever of the two owners is in use

	if (!testDigest)
	{
		mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
		hashAlg = mac.get();
		std::string key = GetDecodedDatum(v, "Key");
		mac->SetKey((const byte *)key.c_str(), key.size());
	}
	else
	{
		hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
		hashAlg = hash.get();
	}

	const bool isVerifyTest = test == "Verify" || test == "VerifyTruncated" || test == "NotVerify";
	if (!isVerifyTest)
	{
		SignalTestError();
		assert(false);
	}
	else
	{
		int digestSize = hashAlg->DigestSize();
		if (test == "VerifyTruncated")
			digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());

		TruncatedHashModule thash(*hashAlg, digestSize);
		HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
		PutDecodedDatumInto(v, "Digest", verifierFilter);
		PutDecodedDatumInto(v, "Message", verifierFilter);
		verifierFilter.MessageEnd();

		// A verify test expects a match, "NotVerify" expects a mismatch.
		if (verifierFilter.GetLastResult() == (test == "NotVerify"))
			SignalTestFailure();
	}
}
00330
00331
bool GetField(std::istream &is, std::string &name, std::string &value)
00332 {
00333 name.resize(0);
00334 is >> name;
00335
if (name.empty())
00336
return false;
00337
00338
if (name[name.size()-1] !=
':')
00339 SignalTestError();
00340 name.erase(name.size()-1);
00341
00342
while (is.peek() ==
' ')
00343 is.ignore(1);
00344
00345
00346
char buffer[128];
00347 value.resize(0);
00348
bool continueLine;
00349
00350
do
00351 {
00352
do
00353 {
00354 is.get(buffer,
sizeof(buffer));
00355 value += buffer;
00356 }
00357
while (buffer[0] != 0);
00358 is.clear();
00359 is.ignore();
00360
00361
if (value[value.size()-1] ==
'\\')
00362 {
00363 value.resize(value.size()-1);
00364 continueLine =
true;
00365 }
00366
else
00367 continueLine =
false;
00368
00369 std::string::size_type i = value.find(
'#');
00370
if (i != std::string::npos)
00371 value.erase(i);
00372 }
00373
while (continueLine);
00374
00375
return true;
00376 }
00377
00378
void OutputPair(
const NameValuePairs &v,
const char *name)
00379 {
00380
Integer x;
00381
bool b = v.
GetValue(name, x);
00382 assert(b);
00383 cout << name <<
": \\\n ";
00384 x.Encode(
HexEncoder(
new FileSink(cout),
false, 64,
"\\\n ").Ref(), x.MinEncodedSize());
00385 cout << endl;
00386 }
00387
00388
void OutputNameValuePairs(
const NameValuePairs &v)
00389 {
00390 std::string names = v.
GetValueNames();
00391 string::size_type i = 0;
00392
while (i < names.size())
00393 {
00394 string::size_type j = names.find_first_of (
';', i);
00395
00396
if (j == string::npos)
00397
return;
00398
else
00399 {
00400 std::string name = names.substr(i, j-i);
00401
if (name.find(
':') == string::npos)
00402 OutputPair(v, name.c_str());
00403 }
00404
00405 i = j + 1;
00406 }
00407 }
00408
00409
void TestDataFile(
const std::string &filename,
unsigned int &totalTests,
unsigned int &failedTests)
00410 {
00411 std::ifstream file(filename.c_str());
00412 TestData v;
00413 s_currentTestData = &v;
00414 std::string name, value, lastAlgName;
00415
00416
while (file)
00417 {
00418
while (file.peek() ==
'#')
00419 file.ignore(INT_MAX,
'\n');
00420
00421
if (file.peek() ==
'\n')
00422 v.clear();
00423
00424
if (!GetField(file, name, value))
00425
break;
00426 v[name] = value;
00427
00428
if (name ==
"Test")
00429 {
00430
bool failed =
true;
00431 std::string algType = GetRequiredDatum(v,
"AlgorithmType");
00432
00433
if (lastAlgName != GetRequiredDatum(v,
"Name"))
00434 {
00435 lastAlgName = GetRequiredDatum(v,
"Name");
00436 cout <<
"\nTesting " << algType.c_str() <<
" algorithm " << lastAlgName.c_str() <<
".\n";
00437 }
00438
00439
try
00440 {
00441
if (algType ==
"Signature")
00442 TestSignatureScheme(v);
00443
else if (algType ==
"SymmetricCipher")
00444 TestSymmetricCipher(v);
00445
else if (algType ==
"AsymmetricCipher")
00446 TestAsymmetricCipher(v);
00447
else if (algType ==
"MessageDigest")
00448 TestDigestOrMAC(v,
true);
00449
else if (algType ==
"MAC")
00450 TestDigestOrMAC(v,
false);
00451
else if (algType ==
"FileList")
00452 TestDataFile(GetRequiredDatum(v,
"Test"), totalTests, failedTests);
00453
else
00454 SignalTestError();
00455 failed =
false;
00456 }
00457
catch (TestFailure &)
00458 {
00459 cout <<
"\nTest failed.\n";
00460 }
00461
catch (CryptoPP::Exception &e)
00462 {
00463 cout <<
"\nCryptoPP::Exception caught: " << e.what() << endl;
00464 }
00465
catch (std::exception &e)
00466 {
00467 cout <<
"\nstd::exception caught: " << e.what() << endl;
00468 }
00469
00470
if (failed)
00471 {
00472 cout <<
"Skipping to next test.\n";
00473 failedTests++;
00474 }
00475
else
00476 cout <<
"." << flush;
00477
00478 totalTests++;
00479 }
00480 }
00481 }
00482
00483
bool RunTestDataFile(
const char *filename)
00484 {
00485 RegisterFactories();
00486
unsigned int totalTests = 0, failedTests = 0;
00487 TestDataFile(filename, totalTests, failedTests);
00488 cout <<
"\nTests complete. Total tests = " << totalTests <<
". Failed tests = " << failedTests <<
".\n";
00489
if (failedTests != 0)
00490 cout <<
"SOME TESTS FAILED!\n";
00491
return failedTests == 0;
00492 }